aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorChristian <c@ethdev.com>2015-01-20 22:58:04 +0800
committerChristian <c@ethdev.com>2015-01-20 22:58:04 +0800
commitd854e56789c4bf2f5caca563dd0e660dababf181 (patch)
tree79b755313238b71fe44ae01504ac95ef6c38e338
parent417f9c03d03a9fa52665bd238a42081d55df62c9 (diff)
downloaddexon-solidity-d854e56789c4bf2f5caca563dd0e660dababf181.tar.gz
dexon-solidity-d854e56789c4bf2f5caca563dd0e660dababf181.tar.zst
dexon-solidity-d854e56789c4bf2f5caca563dd0e660dababf181.zip
Include virtual function overrides in constructor context.
-rw-r--r--AST.cpp7
-rwxr-xr-xAST.h6
-rw-r--r--CallGraph.cpp9
-rw-r--r--CallGraph.h8
-rw-r--r--Compiler.cpp35
-rw-r--r--Compiler.h5
-rw-r--r--Parser.cpp7
-rw-r--r--Parser.h2
-rw-r--r--Types.cpp2
9 files changed, 58 insertions, 23 deletions
diff --git a/AST.cpp b/AST.cpp
index 2cb738d3..82667367 100644
--- a/AST.cpp
+++ b/AST.cpp
@@ -82,7 +82,7 @@ map<FixedHash<4>, FunctionDefinition const*> ContractDefinition::getInterfaceFun
FunctionDefinition const* ContractDefinition::getConstructor() const
{
for (ASTPointer<FunctionDefinition> const& f: m_definedFunctions)
- if (f->getName() == getName())
+ if (f->isConstructor())
return f.get();
return nullptr;
}
@@ -95,7 +95,7 @@ void ContractDefinition::checkIllegalOverrides() const
for (ContractDefinition const* contract: getLinearizedBaseContracts())
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
{
- if (function->getName() == contract->getName())
+ if (function->isConstructor())
continue; // constructors can neither be overriden nor override anything
FunctionDefinition const*& override = functions[function->getName()];
if (!override)
@@ -115,8 +115,7 @@ vector<pair<FixedHash<4>, FunctionDefinition const*>> const& ContractDefinition:
m_interfaceFunctionList.reset(new vector<pair<FixedHash<4>, FunctionDefinition const*>>());
for (ContractDefinition const* contract: getLinearizedBaseContracts())
for (ASTPointer<FunctionDefinition> const& f: contract->getDefinedFunctions())
- if (f->isPublic() && f->getName() != contract->getName() &&
- functionsSeen.count(f->getName()) == 0)
+ if (f->isPublic() && !f->isConstructor() && functionsSeen.count(f->getName()) == 0)
{
functionsSeen.insert(f->getName());
FixedHash<4> hash(dev::sha3(f->getCanonicalSignature()));
diff --git a/AST.h b/AST.h
index 8079348c..97ff30cf 100755
--- a/AST.h
+++ b/AST.h
@@ -280,13 +280,13 @@ class FunctionDefinition: public Declaration
{
public:
FunctionDefinition(Location const& _location, ASTPointer<ASTString> const& _name,
- bool _isPublic,
+ bool _isPublic, bool _isConstructor,
ASTPointer<ASTString> const& _documentation,
ASTPointer<ParameterList> const& _parameters,
bool _isDeclaredConst,
ASTPointer<ParameterList> const& _returnParameters,
ASTPointer<Block> const& _body):
- Declaration(_location, _name), m_isPublic(_isPublic),
+ Declaration(_location, _name), m_isPublic(_isPublic), m_isConstructor(_isConstructor),
m_parameters(_parameters),
m_isDeclaredConst(_isDeclaredConst),
m_returnParameters(_returnParameters),
@@ -298,6 +298,7 @@ public:
virtual void accept(ASTConstVisitor& _visitor) const override;
bool isPublic() const { return m_isPublic; }
+ bool isConstructor() const { return m_isConstructor; }
bool isDeclaredConst() const { return m_isDeclaredConst; }
std::vector<ASTPointer<VariableDeclaration>> const& getParameters() const { return m_parameters->getParameters(); }
ParameterList const& getParameterList() const { return *m_parameters; }
@@ -321,6 +322,7 @@ public:
private:
bool m_isPublic;
+ bool m_isConstructor;
ASTPointer<ParameterList> m_parameters;
bool m_isDeclaredConst;
ASTPointer<ParameterList> m_returnParameters;
diff --git a/CallGraph.cpp b/CallGraph.cpp
index 88d874f3..8766114f 100644
--- a/CallGraph.cpp
+++ b/CallGraph.cpp
@@ -38,6 +38,7 @@ void CallGraph::addNode(ASTNode const& _node)
set<FunctionDefinition const*> const& CallGraph::getCalls()
{
+ computeCallGraph();
return m_functionsSeen;
}
@@ -45,8 +46,7 @@ void CallGraph::computeCallGraph()
{
while (!m_workQueue.empty())
{
- FunctionDefinition const* fun = m_workQueue.front();
- fun->accept(*this);
+ m_workQueue.front()->accept(*this);
m_workQueue.pop();
}
}
@@ -55,7 +55,12 @@ bool CallGraph::visit(Identifier const& _identifier)
{
FunctionDefinition const* fun = dynamic_cast<FunctionDefinition const*>(_identifier.getReferencedDeclaration());
if (fun)
+ {
+ if (m_overrideResolver)
+ fun = (*m_overrideResolver)(fun->getName());
+ solAssert(fun, "");
addFunction(*fun);
+ }
return true;
}
diff --git a/CallGraph.h b/CallGraph.h
index e3558fc2..90176e7e 100644
--- a/CallGraph.h
+++ b/CallGraph.h
@@ -22,6 +22,7 @@
#include <set>
#include <queue>
+#include <functional>
#include <boost/range/iterator_range.hpp>
#include <libsolidity/ASTVisitor.h>
@@ -38,8 +39,11 @@ namespace solidity
class CallGraph: private ASTConstVisitor
{
public:
+ using OverrideResolver = std::function<FunctionDefinition const*(std::string const&)>;
+
+ CallGraph(OverrideResolver const& _overrideResolver): m_overrideResolver(&_overrideResolver) {}
+
void addNode(ASTNode const& _node);
- void computeCallGraph();
std::set<FunctionDefinition const*> const& getCalls();
@@ -48,8 +52,10 @@ private:
virtual bool visit(Identifier const& _identifier) override;
virtual bool visit(MemberAccess const& _memberAccess) override;
+ void computeCallGraph();
void addFunction(FunctionDefinition const& _function);
+ OverrideResolver const* m_overrideResolver;
std::set<FunctionDefinition const*> m_functionsSeen;
std::queue<FunctionDefinition const*> m_workQueue;
};
diff --git a/Compiler.cpp b/Compiler.cpp
index 36316b9a..5a434a71 100644
--- a/Compiler.cpp
+++ b/Compiler.cpp
@@ -43,13 +43,13 @@ void Compiler::compileContract(ContractDefinition const& _contract,
for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
- if (function->getName() != contract->getName()) // don't add the constructor here
+ if (!function->isConstructor())
m_context.addFunction(*function);
appendFunctionSelector(_contract);
for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
- if (function->getName() != contract->getName()) // don't add the constructor here
+ if (!function->isConstructor())
function->accept(*this);
// Swap the runtime context with the creation-time context
@@ -93,11 +93,30 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp
}
}
- //@TODO add virtual functions
- neededFunctions = getFunctionsCalled(nodesUsedInConstructors);
+ auto overrideResolver = [&](string const& _name) -> FunctionDefinition const*
+ {
+ for (ContractDefinition const* contract: bases)
+ for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
+ if (!function->isConstructor() && function->getName() == _name)
+ return function.get();
+ return nullptr;
+ };
+
+ neededFunctions = getFunctionsCalled(nodesUsedInConstructors, overrideResolver);
+ // First add all overrides (or the functions themselves if there is no override)
+ for (FunctionDefinition const* fun: neededFunctions)
+ {
+ FunctionDefinition const* override = nullptr;
+ if (!fun->isConstructor())
+ override = overrideResolver(fun->getName());
+ if (!!override && neededFunctions.count(override))
+ m_context.addFunction(*override);
+ }
+ // now add the rest
for (FunctionDefinition const* fun: neededFunctions)
- m_context.addFunction(*fun);
+ if (fun->isConstructor() || overrideResolver(fun->getName()) != fun)
+ m_context.addFunction(*fun);
// Call constructors in base-to-derived order.
// The Constructor for the most derived contract is called later.
@@ -159,10 +178,10 @@ void Compiler::appendConstructorCall(FunctionDefinition const& _constructor)
m_context << returnTag;
}
-set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes)
+set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes,
+ function<FunctionDefinition const*(string const&)> const& _resolveOverrides)
{
- // TODO this does not add virtual functions
- CallGraph callgraph;
+ CallGraph callgraph(_resolveOverrides);
for (ASTNode const* node: _nodes)
callgraph.addNode(*node);
return callgraph.getCalls();
diff --git a/Compiler.h b/Compiler.h
index ea05f38e..2bae6b39 100644
--- a/Compiler.h
+++ b/Compiler.h
@@ -21,6 +21,7 @@
*/
#include <ostream>
+#include <functional>
#include <libsolidity/ASTVisitor.h>
#include <libsolidity/CompilerContext.h>
@@ -49,7 +50,9 @@ private:
std::vector<ASTPointer<Expression>> const& _arguments);
void appendConstructorCall(FunctionDefinition const& _constructor);
/// Recursively searches the call graph and returns all functions referenced inside _nodes.
- std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes);
+ /// _resolveOverride is called to resolve virtual function overrides.
+ std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes,
+ std::function<FunctionDefinition const*(std::string const&)> const& _resolveOverride);
void appendFunctionSelector(ContractDefinition const& _contract);
/// Creates code that unpacks the arguments for the given function, from memory if
/// @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes.
diff --git a/Parser.cpp b/Parser.cpp
index c0ca1abb..fb864072 100644
--- a/Parser.cpp
+++ b/Parser.cpp
@@ -142,7 +142,7 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition()
expectToken(Token::COLON);
}
else if (currentToken == Token::FUNCTION)
- functions.push_back(parseFunctionDefinition(visibilityIsPublic));
+ functions.push_back(parseFunctionDefinition(visibilityIsPublic, name.get()));
else if (currentToken == Token::STRUCT)
structs.push_back(parseStructDefinition());
else if (currentToken == Token::IDENTIFIER || currentToken == Token::MAPPING ||
@@ -178,7 +178,7 @@ ASTPointer<InheritanceSpecifier> Parser::parseInheritanceSpecifier()
return nodeFactory.createNode<InheritanceSpecifier>(name, arguments);
}
-ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic)
+ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic, ASTString const* _contractName)
{
ASTNodeFactory nodeFactory(*this);
ASTPointer<ASTString> docstring;
@@ -210,7 +210,8 @@ ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic)
}
ASTPointer<Block> block = parseBlock();
nodeFactory.setEndPositionFromNode(block);
- return nodeFactory.createNode<FunctionDefinition>(name, _isPublic, docstring,
+ bool const c_isConstructor = (_contractName && *name == *_contractName);
+ return nodeFactory.createNode<FunctionDefinition>(name, _isPublic, c_isConstructor, docstring,
parameters,
isDeclaredConst, returnParameters, block);
}
diff --git a/Parser.h b/Parser.h
index 1b7a980f..5905a042 100644
--- a/Parser.h
+++ b/Parser.h
@@ -50,7 +50,7 @@ private:
ASTPointer<ImportDirective> parseImportDirective();
ASTPointer<ContractDefinition> parseContractDefinition();
ASTPointer<InheritanceSpecifier> parseInheritanceSpecifier();
- ASTPointer<FunctionDefinition> parseFunctionDefinition(bool _isPublic);
+ ASTPointer<FunctionDefinition> parseFunctionDefinition(bool _isPublic, ASTString const* _contractName);
ASTPointer<StructDefinition> parseStructDefinition();
ASTPointer<VariableDeclaration> parseVariableDeclaration(bool _allowVar);
ASTPointer<TypeName> parseTypeName(bool _allowVar);
diff --git a/Types.cpp b/Types.cpp
index c6d8b62f..2446c513 100644
--- a/Types.cpp
+++ b/Types.cpp
@@ -716,7 +716,7 @@ MemberList const& TypeType::getMembers() const
// We are accessing the type of a base contract, so add all public and private
// functions. Note that this does not add inherited functions on purpose.
for (ASTPointer<FunctionDefinition> const& f: contract.getDefinedFunctions())
- if (f->getName() != contract.getName())
+ if (!f->isConstructor())
members[f->getName()] = make_shared<FunctionType>(*f);
}
m_members.reset(new MemberList(members));