aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTing-Wei Lan <lantw44@gmail.com>2019-04-29 19:11:19 +0800
committerTing-Wei Lan <tingwei.lan@cobinhood.com>2019-05-14 11:04:15 +0800
commit49acaa8391cef7bd069b965d7de827cb0ba683b7 (patch)
tree5a04654ade25bec384f59dd8220f1de7d68a4338
parent77dab86d3a7382e3a0baa79ccd39676b59850518 (diff)
downloaddexon-49acaa8391cef7bd069b965d7de827cb0ba683b7.tar.gz
dexon-49acaa8391cef7bd069b965d7de827cb0ba683b7.tar.zst
dexon-49acaa8391cef7bd069b965d7de827cb0ba683b7.zip
code backup 25
-rw-r--r--core/vm/sqlvm/checker/checker.go268
-rw-r--r--core/vm/sqlvm/checker/utils.go50
2 files changed, 212 insertions, 106 deletions
diff --git a/core/vm/sqlvm/checker/checker.go b/core/vm/sqlvm/checker/checker.go
index dce59f332..5dcfaa043 100644
--- a/core/vm/sqlvm/checker/checker.go
+++ b/core/vm/sqlvm/checker/checker.go
@@ -2052,8 +2052,8 @@ func elAppendTypeErrorOperandDataType(el *errors.ErrorList, n ast.ExprNode,
Severity: errors.ErrorSeverityError,
Prefix: fn,
Message: fmt.Sprintf(
- "cannot use %s (%04x) as an operand of %s because there is "+
- "already an operand declared as %s (%04x)",
+ "cannot use %s (%04x) with %s because the operand is expected "+
+ "to be %s (%04x)",
dtGiven.String(), uint16(dtGiven), op,
dtExpected.String(), uint16(dtExpected)),
}, nil)
@@ -2070,12 +2070,151 @@ func elAppendTypeErrorOperandValueNode(el *errors.ErrorList, n ast.Valuer,
Severity: errors.ErrorSeverityError,
Prefix: fn,
Message: fmt.Sprintf(
- "cannot use %s as an operand of %s because there is "+
- "already an operand found to be %s",
+ "cannot use %s with %s because the other operand is expected "+
+ "to be %s",
describeValueNodeType(n), op, describeValueNodeType(nExpected)),
}, nil)
}
+func extractConstantValue(n ast.Valuer) constantValue {
+ switch n := n.(type) {
+ case *ast.BoolValueNode:
+ return newConstantValueBool(n.V)
+ case *ast.AddressValueNode:
+ return newConstantValueBytes(n.V)
+ case *ast.IntegerValueNode:
+ return newConstantValueDecimal(n.V)
+ case *ast.DecimalValueNode:
+ return newConstantValueDecimal(n.V)
+ case *ast.BytesValueNode:
+ return newConstantValueBytes(n.V)
+ case *ast.NullValueNode:
+ return nil
+ default:
+ panic(unknownValueNodeType(n))
+ }
+}
+
+func unknownConstantValueType(v constantValue) string {
+ return fmt.Sprintf("unknown constant value type %T", v)
+}
+
+func foldRelationalOperator(n ast.BinaryOperator, object, subject ast.Valuer,
+ el *errors.ErrorList, fn, op string,
+ evalBool func(ast.BoolValue, ast.BoolValue) ast.BoolValue,
+ evalBytes func([]byte, []byte) ast.BoolValue,
+ evalDecimal func(decimal.NullDecimal, decimal.NullDecimal) ast.BoolValue,
+) *ast.BoolValueNode {
+
+ compatibleTypes := func() bool {
+ switch object.(type) {
+ case *ast.BoolValueNode:
+ switch subject.(type) {
+ case *ast.BoolValueNode:
+ case *ast.NullValueNode:
+ default:
+ return false
+ }
+ case *ast.AddressValueNode:
+ switch subject.(type) {
+ case *ast.AddressValueNode:
+ case *ast.NullValueNode:
+ default:
+ return false
+ }
+ case *ast.IntegerValueNode:
+ switch subject.(type) {
+ case *ast.IntegerValueNode:
+ case *ast.DecimalValueNode:
+ case *ast.NullValueNode:
+ default:
+ return false
+ }
+ case *ast.DecimalValueNode:
+ switch subject.(type) {
+ case *ast.IntegerValueNode:
+ case *ast.DecimalValueNode:
+ case *ast.NullValueNode:
+ default:
+ return false
+ }
+ case *ast.BytesValueNode:
+ switch subject.(type) {
+ case *ast.BytesValueNode:
+ case *ast.NullValueNode:
+ default:
+ return false
+ }
+ case *ast.NullValueNode:
+ default:
+ panic(unknownValueNodeType(object))
+ }
+ return true
+ }
+
+ if !compatibleTypes() {
+ elAppendTypeErrorOperandValueNode(el, subject, fn, op, object)
+ return nil
+ }
+
+ arg1 := extractConstantValue(object)
+ arg2 := extractConstantValue(subject)
+
+ var vo ast.BoolValue
+ switch v1 := arg1.(type) {
+ case constantValueBool:
+ var v2 ast.BoolValue
+ if arg2 == nil {
+ v2 = ast.BoolValueUnknown
+ } else {
+ v2 = arg2.(constantValueBool).GetBool()
+ }
+ vo = evalBool(v1.GetBool(), v2)
+
+ case constantValueBytes:
+ var v2 []byte
+ if arg2 == nil {
+ v2 = nil
+ } else {
+ v2 = arg2.(constantValueBytes).GetBytes()
+ }
+ vo = evalBytes(v1.GetBytes(), v2)
+
+ case constantValueDecimal:
+ var v2 decimal.NullDecimal
+ if arg2 == nil {
+ v2 = decimal.NullDecimal{Valid: false}
+ } else {
+ v2 = arg2.(constantValueDecimal).GetDecimal()
+ }
+ vo = evalDecimal(v1.GetDecimal(), v2)
+
+ case nil:
+ switch v2 := arg2.(type) {
+ case constantValueBool:
+ vo = evalBool(ast.BoolValueUnknown, v2.GetBool())
+ case constantValueBytes:
+ vo = evalBytes(nil, v2.GetBytes())
+ case constantValueDecimal:
+ vo = evalDecimal(decimal.NullDecimal{Valid: false}, v2.GetDecimal())
+ case nil:
+ vo = evalBool(ast.BoolValueUnknown, ast.BoolValueUnknown)
+ default:
+ panic(unknownConstantValueType(v2))
+ }
+
+ default:
+ panic(unknownConstantValueType(v1))
+ }
+
+ node := &ast.BoolValueNode{}
+ node.SetPosition(n.GetPosition())
+ node.SetLength(n.GetLength())
+ node.SetToken(n.GetToken())
+ node.V = vo
+ return node
+}
+
func checkGreaterOrEqualOperator(n *ast.GreaterOrEqualOperatorNode,
s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList,
tr schema.TableRef, ta typeAction) ast.ExprNode {
@@ -2140,112 +2279,29 @@ func checkGreaterOrEqualOperator(n *ast.GreaterOrEqualOperatorNode,
panic("unreachable")
}
- fold := func(object, subject ast.Valuer) bool {
- var vo ast.BoolValue
- eval:
- switch object := object.(type) {
- case *ast.BoolValueNode:
- var v1, v2 ast.BoolValue
- v1 = object.V
- switch subject := subject.(type) {
- case *ast.BoolValueNode:
- v2 = subject.V
- case *ast.NullValueNode:
- v2 = ast.BoolValueUnknown
- default:
- elAppendTypeErrorOperandValueNode(el, subject, fn, op, object)
- return false
- }
- vo = v1.GreaterOrEqual(v2)
-
- case *ast.AddressValueNode:
- var v1, v2 []byte
- v1 = object.V
- switch subject := subject.(type) {
- case *ast.AddressValueNode:
- v2 = subject.V
- case *ast.NullValueNode:
- vo = ast.BoolValueUnknown
- break eval
- default:
- elAppendTypeErrorOperandValueNode(el, subject, fn, op, object)
- return false
- }
- vo = ast.NewBoolValueFromBool(bytes.Compare(v1, v2) >= 0)
-
- case *ast.IntegerValueNode:
- var v1, v2 decimal.Decimal
- v1 = object.V
- switch subject := subject.(type) {
- case *ast.IntegerValueNode:
- v2 = subject.V
- case *ast.DecimalValueNode:
- v2 = subject.V
- case *ast.NullValueNode:
- vo = ast.BoolValueUnknown
- break eval
- default:
- elAppendTypeErrorOperandValueNode(el, subject, fn, op, object)
- return false
- }
- vo = ast.NewBoolValueFromBool(v1.GreaterThanOrEqual(v2))
-
- case *ast.DecimalValueNode:
- var v1, v2 decimal.Decimal
- v1 = object.V
- switch subject := subject.(type) {
- case *ast.IntegerValueNode:
- v2 = subject.V
- case *ast.DecimalValueNode:
- v2 = subject.V
- case *ast.NullValueNode:
- vo = ast.BoolValueUnknown
- break eval
- default:
- elAppendTypeErrorOperandValueNode(el, subject, fn, op, object)
- return false
- }
- vo = ast.NewBoolValueFromBool(v1.GreaterThanOrEqual(v2))
-
- case *ast.BytesValueNode:
- var v1, v2 []byte
- v1 = object.V
- switch subject := subject.(type) {
- case *ast.BytesValueNode:
- v2 = subject.V
- case *ast.NullValueNode:
- vo = ast.BoolValueUnknown
- break eval
- default:
- elAppendTypeErrorOperandValueNode(el, subject, fn, op, object)
- return false
- }
- vo = ast.NewBoolValueFromBool(bytes.Compare(v1, v2) >= 0)
-
- case *ast.NullValueNode:
- switch subject := subject.(type) {
- case *ast.BoolValueNode:
- vo = ast.BoolValueUnknown.GreaterOrEqual(subject.V)
- default:
- vo = ast.BoolValueUnknown
- }
-
- default:
- panic(unknownValueNodeType(object))
- }
- node := &ast.BoolValueNode{}
- node.SetPosition(n.GetPosition())
- node.SetLength(n.GetLength())
- node.SetToken(n.GetToken())
- node.V = vo
- r = node
- return true
- }
if object, ok := object.(ast.Valuer); ok {
if subject, ok := subject.(ast.Valuer); ok {
- if !fold(object, subject) {
+ node := foldRelationalOperator(n, object, subject, el, fn, op,
+ func(v1, v2 ast.BoolValue) ast.BoolValue {
+ return v1.GreaterOrEqual(v2)
+ },
+ func(v1, v2 []byte) ast.BoolValue {
+ if v1 == nil || v2 == nil {
+ return ast.BoolValueUnknown
+ }
+ return ast.NewBoolValueFromBool(bytes.Compare(v1, v2) >= 0)
+ },
+ func(v1, v2 decimal.NullDecimal) ast.BoolValue {
+ if !v1.Valid || !v2.Valid {
+ return ast.BoolValueUnknown
+ }
+ return ast.NewBoolValueFromBool(
+ v1.Decimal.GreaterThanOrEqual(v2.Decimal))
+ })
+ if node == nil {
return nil
}
+ r = node
}
}
diff --git a/core/vm/sqlvm/checker/utils.go b/core/vm/sqlvm/checker/utils.go
index 34b73af4a..3ca49676f 100644
--- a/core/vm/sqlvm/checker/utils.go
+++ b/core/vm/sqlvm/checker/utils.go
@@ -491,3 +491,53 @@ func newTypeActionAssign(expected ast.DataType) typeActionAssign {
var _ typeAction = typeActionAssign{}
func (typeActionAssign) ˉtypeAction() {}
+
+//go-sumtype:decl constantValue
+type constantValue interface {
+ ˉconstantValue()
+}
+
+type constantValueBool ast.BoolValue
+
+var _ constantValue = constantValueBool(0)
+
+func (constantValueBool) ˉconstantValue() {}
+
+func newConstantValueBool(b ast.BoolValue) constantValueBool {
+ return constantValueBool(b)
+}
+
+func (b constantValueBool) GetBool() ast.BoolValue {
+ return ast.BoolValue(b)
+}
+
+type constantValueBytes []byte
+
+var _ constantValue = constantValueBytes{}
+
+func (constantValueBytes) ˉconstantValue() {}
+
+func newConstantValueBytes(b []byte) constantValueBytes {
+ if b == nil {
+ return constantValueBytes{}
+ }
+ return constantValueBytes(b)
+}
+
+func (b constantValueBytes) GetBytes() []byte {
+ return []byte(b)
+}
+
+type constantValueDecimal decimal.NullDecimal
+
+var _ constantValue = constantValueDecimal{}
+
+func (constantValueDecimal) ˉconstantValue() {}
+
+func newConstantValueDecimal(d decimal.Decimal) constantValueDecimal {
+ return constantValueDecimal{Decimal: d, Valid: true}
+}
+
+func (d constantValueDecimal) GetDecimal() decimal.NullDecimal {
+ return decimal.NullDecimal(d)
+}