diff options
author | Ting-Wei Lan <lantw44@gmail.com> | 2019-04-29 19:11:19 +0800 |
---|---|---|
committer | Ting-Wei Lan <tingwei.lan@cobinhood.com> | 2019-05-14 11:04:15 +0800 |
commit | 49acaa8391cef7bd069b965d7de827cb0ba683b7 (patch) | |
tree | 5a04654ade25bec384f59dd8220f1de7d68a4338 | |
parent | 77dab86d3a7382e3a0baa79ccd39676b59850518 (diff) | |
download | dexon-49acaa8391cef7bd069b965d7de827cb0ba683b7.tar.gz dexon-49acaa8391cef7bd069b965d7de827cb0ba683b7.tar.zst dexon-49acaa8391cef7bd069b965d7de827cb0ba683b7.zip |
code backup 25
-rw-r--r-- | core/vm/sqlvm/checker/checker.go | 268 | ||||
-rw-r--r-- | core/vm/sqlvm/checker/utils.go | 50 |
2 files changed, 212 insertions, 106 deletions
diff --git a/core/vm/sqlvm/checker/checker.go b/core/vm/sqlvm/checker/checker.go index dce59f332..5dcfaa043 100644 --- a/core/vm/sqlvm/checker/checker.go +++ b/core/vm/sqlvm/checker/checker.go @@ -2052,8 +2052,8 @@ func elAppendTypeErrorOperandDataType(el *errors.ErrorList, n ast.ExprNode, Severity: errors.ErrorSeverityError, Prefix: fn, Message: fmt.Sprintf( - "cannot use %s (%04x) as an operand of %s because there is "+ - "already an operand declared as %s (%04x)", + "cannot use %s (%04x) with %s because the operand is expected "+ + "to be %s (%04x)", dtGiven.String(), uint16(dtGiven), op, dtExpected.String(), uint16(dtExpected)), }, nil) @@ -2070,12 +2070,151 @@ func elAppendTypeErrorOperandValueNode(el *errors.ErrorList, n ast.Valuer, Severity: errors.ErrorSeverityError, Prefix: fn, Message: fmt.Sprintf( - "cannot use %s as an operand of %s because there is "+ - "already an operand found to be %s", + "cannot use %s with %s because the other operand is expected "+ + "to be %s", describeValueNodeType(n), op, describeValueNodeType(nExpected)), }, nil) } +func extractConstantValue(n ast.Valuer) constantValue { + switch n := n.(type) { + case *ast.BoolValueNode: + return newConstantValueBool(n.V) + case *ast.AddressValueNode: + return newConstantValueBytes(n.V) + case *ast.IntegerValueNode: + return newConstantValueDecimal(n.V) + case *ast.DecimalValueNode: + return newConstantValueDecimal(n.V) + case *ast.BytesValueNode: + return newConstantValueBytes(n.V) + case *ast.NullValueNode: + return nil + default: + panic(unknownValueNodeType(n)) + } +} + +func unknownConstantValueType(v constantValue) string { + return fmt.Sprintf("unknown constant value type %T", v) +} + +func foldRelationalOperator(n ast.BinaryOperator, object, subject ast.Valuer, + el *errors.ErrorList, fn, op string, + evalBool func(ast.BoolValue, ast.BoolValue) ast.BoolValue, + evalBytes func([]byte, []byte) ast.BoolValue, + evalDecimal func(decimal.NullDecimal, decimal.NullDecimal) ast.BoolValue, +) *ast.BoolValueNode { + + compatibleTypes := func() bool { + switch object.(type) { + case *ast.BoolValueNode: + switch subject.(type) { + case *ast.BoolValueNode: + case *ast.NullValueNode: + default: + return false + } + case *ast.AddressValueNode: + switch subject.(type) { + case *ast.AddressValueNode: + case *ast.NullValueNode: + default: + return false + } + case *ast.IntegerValueNode: + switch subject.(type) { + case *ast.IntegerValueNode: + case *ast.DecimalValueNode: + case *ast.NullValueNode: + default: + return false + } + case *ast.DecimalValueNode: + switch subject.(type) { + case *ast.IntegerValueNode: + case *ast.DecimalValueNode: + case *ast.NullValueNode: + default: + return false + } + case *ast.BytesValueNode: + switch subject.(type) { + case *ast.BytesValueNode: + case *ast.NullValueNode: + default: + return false + } + case *ast.NullValueNode: + default: + panic(unknownValueNodeType(object)) + } + return true + } + + if !compatibleTypes() { + elAppendTypeErrorOperandValueNode(el, subject, fn, op, object) + return nil + } + + arg1 := extractConstantValue(object) + arg2 := extractConstantValue(subject) + + var vo ast.BoolValue + switch v1 := arg1.(type) { + case constantValueBool: + var v2 ast.BoolValue + if arg2 == nil { + v2 = ast.BoolValueUnknown + } else { + v2 = arg2.(constantValueBool).GetBool() + } + vo = evalBool(v1.GetBool(), v2) + + case constantValueBytes: + var v2 []byte + if arg2 == nil { + v2 = nil + } else { + v2 = arg2.(constantValueBytes).GetBytes() + } + vo = evalBytes(v1.GetBytes(), v2) + + case constantValueDecimal: + var v2 decimal.NullDecimal + if arg2 == nil { + v2 = decimal.NullDecimal{Valid: false} + } else { + v2 = arg2.(constantValueDecimal).GetDecimal() + } + vo = evalDecimal(v1.GetDecimal(), v2) + + case nil: + switch v2 := arg2.(type) { + case constantValueBool: + vo = evalBool(ast.BoolValueUnknown, v2.GetBool()) + case constantValueBytes: + vo = evalBytes(nil, v2.GetBytes()) + case constantValueDecimal: + vo = evalDecimal(decimal.NullDecimal{Valid: false}, v2.GetDecimal()) + case nil: + vo = evalBool(ast.BoolValueUnknown, ast.BoolValueUnknown) + default: + panic(unknownConstantValueType(v2)) + } + + default: + panic(unknownConstantValueType(v1)) + } + + node := &ast.BoolValueNode{} + node.SetPosition(n.GetPosition()) + node.SetLength(n.GetLength()) + node.SetToken(n.GetToken()) + node.V = vo + return node +} + func checkGreaterOrEqualOperator(n *ast.GreaterOrEqualOperatorNode, s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList, tr schema.TableRef, ta typeAction) ast.ExprNode { @@ -2140,112 +2279,29 @@ func checkGreaterOrEqualOperator(n *ast.GreaterOrEqualOperatorNode, panic("unreachable") } - fold := func(object, subject ast.Valuer) bool { - var vo ast.BoolValue - eval: - switch object := object.(type) { - case *ast.BoolValueNode: - var v1, v2 ast.BoolValue - v1 = object.V - switch subject := subject.(type) { - case *ast.BoolValueNode: - v2 = subject.V - case *ast.NullValueNode: - v2 = ast.BoolValueUnknown - default: - elAppendTypeErrorOperandValueNode(el, subject, fn, op, object) - return false - } - vo = v1.GreaterOrEqual(v2) - - case *ast.AddressValueNode: - var v1, v2 []byte - v1 = object.V - switch subject := subject.(type) { - case *ast.AddressValueNode: - v2 = subject.V - case *ast.NullValueNode: - vo = ast.BoolValueUnknown - break eval - default: - elAppendTypeErrorOperandValueNode(el, subject, fn, op, object) - return false - } - vo = ast.NewBoolValueFromBool(bytes.Compare(v1, v2) >= 0) - - case *ast.IntegerValueNode: - var v1, v2 decimal.Decimal - v1 = object.V - switch subject := subject.(type) { - case *ast.IntegerValueNode: - v2 = subject.V - case *ast.DecimalValueNode: - v2 = subject.V - case *ast.NullValueNode: - vo = ast.BoolValueUnknown - break eval - default: - elAppendTypeErrorOperandValueNode(el, subject, fn, op, object) - return false - } - vo = ast.NewBoolValueFromBool(v1.GreaterThanOrEqual(v2)) - - case *ast.DecimalValueNode: - var v1, v2 decimal.Decimal - v1 = object.V - switch subject := subject.(type) { - case *ast.IntegerValueNode: - v2 = subject.V - case *ast.DecimalValueNode: - v2 = subject.V - case *ast.NullValueNode: - vo = ast.BoolValueUnknown - break eval - default: - elAppendTypeErrorOperandValueNode(el, subject, fn, op, object) - return false - } - vo = ast.NewBoolValueFromBool(v1.GreaterThanOrEqual(v2)) - - case *ast.BytesValueNode: - var v1, v2 []byte - v1 = object.V - switch subject := subject.(type) { - case *ast.BytesValueNode: - v2 = subject.V - case *ast.NullValueNode: - vo = ast.BoolValueUnknown - break eval - default: - elAppendTypeErrorOperandValueNode(el, subject, fn, op, object) - return false - } - vo = ast.NewBoolValueFromBool(bytes.Compare(v1, v2) >= 0) - - case *ast.NullValueNode: - switch subject := subject.(type) { - case *ast.BoolValueNode: - vo = ast.BoolValueUnknown.GreaterOrEqual(subject.V) - default: - vo = ast.BoolValueUnknown - } - - default: - panic(unknownValueNodeType(object)) - } - node := &ast.BoolValueNode{} - node.SetPosition(n.GetPosition()) - node.SetLength(n.GetLength()) - node.SetToken(n.GetToken()) - node.V = vo - r = node - return true - } if object, ok := object.(ast.Valuer); ok { if subject, ok := subject.(ast.Valuer); ok { - if !fold(object, subject) { + node := foldRelationalOperator(n, object, subject, el, fn, op, + func(v1, v2 ast.BoolValue) ast.BoolValue { + return v1.GreaterOrEqual(v2) + }, + func(v1, v2 []byte) ast.BoolValue { + if v1 == nil || v2 == nil { + return ast.BoolValueUnknown + } + return ast.NewBoolValueFromBool(bytes.Compare(v1, v2) >= 0) + }, + func(v1, v2 decimal.NullDecimal) ast.BoolValue { + if !v1.Valid || !v2.Valid { + return ast.BoolValueUnknown + } + return ast.NewBoolValueFromBool( + v1.Decimal.GreaterThanOrEqual(v2.Decimal)) + }) + if node == nil { return nil } + r = node } } diff --git a/core/vm/sqlvm/checker/utils.go b/core/vm/sqlvm/checker/utils.go index 34b73af4a..3ca49676f 100644 --- a/core/vm/sqlvm/checker/utils.go +++ b/core/vm/sqlvm/checker/utils.go @@ -491,3 +491,53 @@ func newTypeActionAssign(expected ast.DataType) typeActionAssign { var _ typeAction = typeActionAssign{} func (typeActionAssign) ˉtypeAction() {} + +//go-sumtype:decl constantValue +type constantValue interface { + ˉconstantValue() +} + +type constantValueBool ast.BoolValue + +var _ constantValue = constantValueBool(0) + +func (constantValueBool) ˉconstantValue() {} + +func newConstantValueBool(b ast.BoolValue) constantValueBool { + return constantValueBool(b) +} + +func (b constantValueBool) GetBool() ast.BoolValue { + return ast.BoolValue(b) +} + +type constantValueBytes []byte + +var _ constantValue = constantValueBytes{} + +func (constantValueBytes) ˉconstantValue() {} + +func newConstantValueBytes(b []byte) constantValueBytes { + if b == nil { + return constantValueBytes{} + } + return constantValueBytes(b) +} + +func (b constantValueBytes) GetBytes() []byte { + return []byte(b) +} + +type constantValueDecimal decimal.NullDecimal + +var _ constantValue = constantValueDecimal{} + +func (constantValueDecimal) ˉconstantValue() {} + +func newConstantValueDecimal(d decimal.Decimal) constantValueDecimal { + return constantValueDecimal{Decimal: d, Valid: true} +} + +func (d constantValueDecimal) GetDecimal() decimal.NullDecimal { + return decimal.NullDecimal(d) +} |