aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTing-Wei Lan <tingwei.lan@cobinhood.com>2019-03-04 16:06:14 +0800
committerJhih-Ming Huang <jm.huang@cobinhood.com>2019-04-11 10:39:58 +0800
commit98ec5bb9701fb5353f34e398fcb4fbb2eeee765b (patch)
treeb1ca53e19718015ba57801f7230b4099e8355ce5
parent469175a930d74a78c99d4273502d2662dbe42986 (diff)
downloaddexon-98ec5bb9701fb5353f34e398fcb4fbb2eeee765b.tar.gz
dexon-98ec5bb9701fb5353f34e398fcb4fbb2eeee765b.tar.zst
dexon-98ec5bb9701fb5353f34e398fcb4fbb2eeee765b.zip
core: vm: sqlvm: limit the depth of AST to 1024
Since we traverse an AST by calling functions recursively, we have to protect the parser by limiting the depth of an AST.
-rw-r--r--core/vm/sqlvm/ast/constants.go5
-rw-r--r--core/vm/sqlvm/errors/errors.go4
-rw-r--r--core/vm/sqlvm/parser/parser.go48
3 files changed, 50 insertions, 7 deletions
diff --git a/core/vm/sqlvm/ast/constants.go b/core/vm/sqlvm/ast/constants.go
new file mode 100644
index 000000000..a11a182ea
--- /dev/null
+++ b/core/vm/sqlvm/ast/constants.go
@@ -0,0 +1,5 @@
+package ast
+
+// DepthLimit is the limit of AST depth used to prevent exhausting the stack
+// when traversing the tree recursively.
+const DepthLimit = 1024
diff --git a/core/vm/sqlvm/errors/errors.go b/core/vm/sqlvm/errors/errors.go
index 886f3beb7..696d061c8 100644
--- a/core/vm/sqlvm/errors/errors.go
+++ b/core/vm/sqlvm/errors/errors.go
@@ -61,12 +61,14 @@ type ErrorCategory uint16
// Error category starts from 1. Zero value is invalid.
const (
ErrorCategoryNil ErrorCategory = iota
+ ErrorCategoryLimit
ErrorCategoryGrammar
ErrorCategorySemantic
ErrorCategoryRuntime
)
var errorCategoryMap = [...]string{
+ ErrorCategoryLimit: "limit",
ErrorCategoryGrammar: "grammar",
ErrorCategorySemantic: "semantic",
ErrorCategoryRuntime: "runtime",
@@ -82,6 +84,7 @@ type ErrorCode uint16
// Error code starts from 1. Zero value is invalid.
const (
ErrorCodeNil ErrorCode = iota
+ ErrorCodeDepthLimitReached
ErrorCodeParser
ErrorCodeInvalidIntegerSyntax
ErrorCodeInvalidNumberSyntax
@@ -108,6 +111,7 @@ const (
)
var errorCodeMap = [...]string{
+ ErrorCodeDepthLimitReached: "depth limit reached",
ErrorCodeParser: "parser error",
ErrorCodeInvalidIntegerSyntax: "invalid integer syntax",
ErrorCodeInvalidNumberSyntax: "invalid number syntax",
diff --git a/core/vm/sqlvm/parser/parser.go b/core/vm/sqlvm/parser/parser.go
index a90fec71c..8ed94e7aa 100644
--- a/core/vm/sqlvm/parser/parser.go
+++ b/core/vm/sqlvm/parser/parser.go
@@ -9,20 +9,40 @@ import (
"github.com/dexon-foundation/dexon/core/vm/sqlvm/parser/internal"
)
-func walkSelfFirst(n ast.Node, v func(ast.Node, []ast.Node)) {
+type visitor func(ast.Node, []ast.Node)
+
+func walkSelfFirst(n ast.Node, v visitor) bool {
+ return walkSelfFirstWithDepth(n, v, 0)
+}
+
+func walkSelfFirstWithDepth(n ast.Node, v visitor, d int) bool {
+ if d >= ast.DepthLimit {
+ return false
+ }
c := n.GetChildren()
+ r := true
v(n, c)
for i := range c {
- walkSelfFirst(c[i], v)
+ r = r && walkSelfFirstWithDepth(c[i], v, d+1)
}
+ return r
}
-func walkChildrenFirst(n ast.Node, v func(ast.Node, []ast.Node)) {
+func walkChildrenFirst(n ast.Node, v visitor) bool {
+ return walkChildrenFirstWithDepth(n, v, 0)
+}
+
+func walkChildrenFirstWithDepth(n ast.Node, v visitor, d int) bool {
+ if d >= ast.DepthLimit {
+ return false
+ }
c := n.GetChildren()
+ r := true
for i := range c {
- walkChildrenFirst(c[i], v)
+ r = r && walkChildrenFirstWithDepth(c[i], v, d+1)
}
v(n, c)
+ return r
}
// Parse parses SQL commands text and return an AST.
@@ -65,7 +85,8 @@ func Parse(b []byte) ([]ast.Node, error) {
if stmts[i] == nil {
continue
}
- walkChildrenFirst(stmts[i], func(n ast.Node, c []ast.Node) {
+ r := true
+ r = r && walkChildrenFirst(stmts[i], func(n ast.Node, c []ast.Node) {
minBegin := uint32(len(eb))
maxEnd := uint32(0)
for _, cn := range append(c, n) {
@@ -83,7 +104,7 @@ func Parse(b []byte) ([]ast.Node, error) {
n.SetPosition(minBegin)
n.SetLength(maxEnd - minBegin)
})
- walkSelfFirst(stmts[i], func(n ast.Node, _ []ast.Node) {
+ r = r && walkSelfFirst(stmts[i], func(n ast.Node, _ []ast.Node) {
begin := n.GetPosition()
end := begin + n.GetLength()
fixedBegin, ok := encMap[begin]
@@ -97,9 +118,22 @@ func Parse(b []byte) ([]ast.Node, error) {
n.SetPosition(fixedBegin)
n.SetLength(fixedEnd - fixedBegin)
})
+ if !r {
+ return nil, errors.ErrorList{
+ errors.Error{
+ Position: 0,
+ Category: errors.ErrorCategoryLimit,
+ Code: errors.ErrorCodeDepthLimitReached,
+ Token: "",
+ Prefix: "",
+ Message: fmt.Sprintf("reach syntax tree depth limit %d",
+ ast.DepthLimit),
+ },
+ }
+ }
}
if pigeonErr == nil {
- return stmts, pigeonErr
+ return stmts, nil
}
// Process errors.