diff options
-rw-r--r-- | AST.cpp | 7 | ||||
-rwxr-xr-x | AST.h | 6 | ||||
-rw-r--r-- | CallGraph.cpp | 9 | ||||
-rw-r--r-- | CallGraph.h | 8 | ||||
-rw-r--r-- | Compiler.cpp | 35 | ||||
-rw-r--r-- | Compiler.h | 5 | ||||
-rw-r--r-- | Parser.cpp | 7 | ||||
-rw-r--r-- | Parser.h | 2 | ||||
-rw-r--r-- | Types.cpp | 2 |
9 files changed, 58 insertions, 23 deletions
@@ -82,7 +82,7 @@ map<FixedHash<4>, FunctionDefinition const*> ContractDefinition::getInterfaceFun FunctionDefinition const* ContractDefinition::getConstructor() const { for (ASTPointer<FunctionDefinition> const& f: m_definedFunctions) - if (f->getName() == getName()) + if (f->isConstructor()) return f.get(); return nullptr; } @@ -95,7 +95,7 @@ void ContractDefinition::checkIllegalOverrides() const for (ContractDefinition const* contract: getLinearizedBaseContracts()) for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) { - if (function->getName() == contract->getName()) + if (function->isConstructor()) continue; // constructors can neither be overriden nor override anything FunctionDefinition const*& override = functions[function->getName()]; if (!override) @@ -115,8 +115,7 @@ vector<pair<FixedHash<4>, FunctionDefinition const*>> const& ContractDefinition: m_interfaceFunctionList.reset(new vector<pair<FixedHash<4>, FunctionDefinition const*>>()); for (ContractDefinition const* contract: getLinearizedBaseContracts()) for (ASTPointer<FunctionDefinition> const& f: contract->getDefinedFunctions()) - if (f->isPublic() && f->getName() != contract->getName() && - functionsSeen.count(f->getName()) == 0) + if (f->isPublic() && !f->isConstructor() && functionsSeen.count(f->getName()) == 0) { functionsSeen.insert(f->getName()); FixedHash<4> hash(dev::sha3(f->getCanonicalSignature())); @@ -280,13 +280,13 @@ class FunctionDefinition: public Declaration { public: FunctionDefinition(Location const& _location, ASTPointer<ASTString> const& _name, - bool _isPublic, + bool _isPublic, bool _isConstructor, ASTPointer<ASTString> const& _documentation, ASTPointer<ParameterList> const& _parameters, bool _isDeclaredConst, ASTPointer<ParameterList> const& _returnParameters, ASTPointer<Block> const& _body): - Declaration(_location, _name), m_isPublic(_isPublic), + Declaration(_location, _name), m_isPublic(_isPublic), m_isConstructor(_isConstructor), m_parameters(_parameters), m_isDeclaredConst(_isDeclaredConst), m_returnParameters(_returnParameters), @@ -298,6 +298,7 @@ public: virtual void accept(ASTConstVisitor& _visitor) const override; bool isPublic() const { return m_isPublic; } + bool isConstructor() const { return m_isConstructor; } bool isDeclaredConst() const { return m_isDeclaredConst; } std::vector<ASTPointer<VariableDeclaration>> const& getParameters() const { return m_parameters->getParameters(); } ParameterList const& getParameterList() const { return *m_parameters; } @@ -321,6 +322,7 @@ public: private: bool m_isPublic; + bool m_isConstructor; ASTPointer<ParameterList> m_parameters; bool m_isDeclaredConst; ASTPointer<ParameterList> m_returnParameters; diff --git a/CallGraph.cpp b/CallGraph.cpp index 88d874f3..8766114f 100644 --- a/CallGraph.cpp +++ b/CallGraph.cpp @@ -38,6 +38,7 @@ void CallGraph::addNode(ASTNode const& _node) set<FunctionDefinition const*> const& CallGraph::getCalls() { + computeCallGraph(); return m_functionsSeen; } @@ -45,8 +46,7 @@ void CallGraph::computeCallGraph() { while (!m_workQueue.empty()) { - FunctionDefinition const* fun = m_workQueue.front(); - fun->accept(*this); + m_workQueue.front()->accept(*this); m_workQueue.pop(); } } @@ -55,7 +55,12 @@ bool CallGraph::visit(Identifier const& _identifier) { FunctionDefinition const* fun = dynamic_cast<FunctionDefinition const*>(_identifier.getReferencedDeclaration()); if (fun) + { + if (m_overrideResolver) + fun = (*m_overrideResolver)(fun->getName()); + solAssert(fun, ""); addFunction(*fun); + } return true; } diff --git a/CallGraph.h b/CallGraph.h index e3558fc2..90176e7e 100644 --- a/CallGraph.h +++ b/CallGraph.h @@ -22,6 +22,7 @@ #include <set> #include <queue> +#include <functional> #include <boost/range/iterator_range.hpp> #include <libsolidity/ASTVisitor.h> @@ -38,8 +39,11 @@ namespace solidity class CallGraph: private ASTConstVisitor { public: + using OverrideResolver = std::function<FunctionDefinition const*(std::string const&)>; + + CallGraph(OverrideResolver const& _overrideResolver): m_overrideResolver(&_overrideResolver) {} + void addNode(ASTNode const& _node); - void computeCallGraph(); std::set<FunctionDefinition const*> const& getCalls(); @@ -48,8 +52,10 @@ private: virtual bool visit(Identifier const& _identifier) override; virtual bool visit(MemberAccess const& _memberAccess) override; + void computeCallGraph(); void addFunction(FunctionDefinition const& _function); + OverrideResolver const* m_overrideResolver; std::set<FunctionDefinition const*> m_functionsSeen; std::queue<FunctionDefinition const*> m_workQueue; }; diff --git a/Compiler.cpp b/Compiler.cpp index 36316b9a..5a434a71 100644 --- a/Compiler.cpp +++ b/Compiler.cpp @@ -43,13 +43,13 @@ void Compiler::compileContract(ContractDefinition const& _contract, for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts()) for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) - if (function->getName() != contract->getName()) // don't add the constructor here + if (!function->isConstructor()) m_context.addFunction(*function); appendFunctionSelector(_contract); for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts()) for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) - if (function->getName() != contract->getName()) // don't add the constructor here + if (!function->isConstructor()) function->accept(*this); // Swap the runtime context with the creation-time context @@ -93,11 +93,30 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp } } - //@TODO add virtual functions - neededFunctions = getFunctionsCalled(nodesUsedInConstructors); + auto overrideResolver = [&](string const& _name) -> FunctionDefinition const* + { + for (ContractDefinition const* contract: bases) + for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) + if (!function->isConstructor() && function->getName() == _name) + return function.get(); + return nullptr; + }; + + neededFunctions = getFunctionsCalled(nodesUsedInConstructors, overrideResolver); + // First add all overrides (or the functions themselves if there is no override) + for (FunctionDefinition const* fun: neededFunctions) + { + FunctionDefinition const* override = nullptr; + if (!fun->isConstructor()) + override = overrideResolver(fun->getName()); + if (!!override && neededFunctions.count(override)) + m_context.addFunction(*override); + } + // now add the rest for (FunctionDefinition const* fun: neededFunctions) - m_context.addFunction(*fun); + if (fun->isConstructor() || overrideResolver(fun->getName()) != fun) + m_context.addFunction(*fun); // Call constructors in base-to-derived order. // The Constructor for the most derived contract is called later. @@ -159,10 +178,10 @@ void Compiler::appendConstructorCall(FunctionDefinition const& _constructor) m_context << returnTag; } -set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes) +set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes, + function<FunctionDefinition const*(string const&)> const& _resolveOverrides) { - // TODO this does not add virtual functions - CallGraph callgraph; + CallGraph callgraph(_resolveOverrides); for (ASTNode const* node: _nodes) callgraph.addNode(*node); return callgraph.getCalls(); @@ -21,6 +21,7 @@ */ #include <ostream> +#include <functional> #include <libsolidity/ASTVisitor.h> #include <libsolidity/CompilerContext.h> @@ -49,7 +50,9 @@ private: std::vector<ASTPointer<Expression>> const& _arguments); void appendConstructorCall(FunctionDefinition const& _constructor); /// Recursively searches the call graph and returns all functions referenced inside _nodes. - std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes); + /// _resolveOverride is called to resolve virtual function overrides. + std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes, + std::function<FunctionDefinition const*(std::string const&)> const& _resolveOverride); void appendFunctionSelector(ContractDefinition const& _contract); /// Creates code that unpacks the arguments for the given function, from memory if /// @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes. @@ -142,7 +142,7 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition() expectToken(Token::COLON); } else if (currentToken == Token::FUNCTION) - functions.push_back(parseFunctionDefinition(visibilityIsPublic)); + functions.push_back(parseFunctionDefinition(visibilityIsPublic, name.get())); else if (currentToken == Token::STRUCT) structs.push_back(parseStructDefinition()); else if (currentToken == Token::IDENTIFIER || currentToken == Token::MAPPING || @@ -178,7 +178,7 @@ ASTPointer<InheritanceSpecifier> Parser::parseInheritanceSpecifier() return nodeFactory.createNode<InheritanceSpecifier>(name, arguments); } -ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic) +ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic, ASTString const* _contractName) { ASTNodeFactory nodeFactory(*this); ASTPointer<ASTString> docstring; @@ -210,7 +210,8 @@ ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic) } ASTPointer<Block> block = parseBlock(); nodeFactory.setEndPositionFromNode(block); - return nodeFactory.createNode<FunctionDefinition>(name, _isPublic, docstring, + bool const c_isConstructor = (_contractName && *name == *_contractName); + return nodeFactory.createNode<FunctionDefinition>(name, _isPublic, c_isConstructor, docstring, parameters, isDeclaredConst, returnParameters, block); } @@ -50,7 +50,7 @@ private: ASTPointer<ImportDirective> parseImportDirective(); ASTPointer<ContractDefinition> parseContractDefinition(); ASTPointer<InheritanceSpecifier> parseInheritanceSpecifier(); - ASTPointer<FunctionDefinition> parseFunctionDefinition(bool _isPublic); + ASTPointer<FunctionDefinition> parseFunctionDefinition(bool _isPublic, ASTString const* _contractName); ASTPointer<StructDefinition> parseStructDefinition(); ASTPointer<VariableDeclaration> parseVariableDeclaration(bool _allowVar); ASTPointer<TypeName> parseTypeName(bool _allowVar); @@ -716,7 +716,7 @@ MemberList const& TypeType::getMembers() const // We are accessing the type of a base contract, so add all public and private // functions. Note that this does not add inherited functions on purpose. for (ASTPointer<FunctionDefinition> const& f: contract.getDefinedFunctions()) - if (f->getName() != contract.getName()) + if (!f->isConstructor()) members[f->getName()] = make_shared<FunctionType>(*f); } m_members.reset(new MemberList(members)); |