diff options
author | Christian <c@ethdev.com> | 2015-01-23 09:46:31 +0800 |
---|---|---|
committer | Christian <c@ethdev.com> | 2015-01-26 17:23:39 +0800 |
commit | fd5899d03806d82e228da12f8cbe151f079ea41f (patch) | |
tree | f839426ef2d921062ac8206ced3089acd318496a | |
parent | 7ded95c776717cf96e96dffb7425c86b47ad8b0e (diff) | |
download | dexon-solidity-fd5899d03806d82e228da12f8cbe151f079ea41f.tar.gz dexon-solidity-fd5899d03806d82e228da12f8cbe151f079ea41f.tar.zst dexon-solidity-fd5899d03806d82e228da12f8cbe151f079ea41f.zip |
Modifier overrides and callgraph analysis.
-rw-r--r-- | CallGraph.cpp | 35 | ||||
-rw-r--r-- | CallGraph.h | 15 | ||||
-rw-r--r-- | Compiler.cpp | 52 | ||||
-rw-r--r-- | Compiler.h | 3 | ||||
-rw-r--r-- | CompilerContext.cpp | 12 | ||||
-rw-r--r-- | CompilerContext.h | 5 |
6 files changed, 79 insertions, 43 deletions
diff --git a/CallGraph.cpp b/CallGraph.cpp index a671796b..5f8fc547 100644 --- a/CallGraph.cpp +++ b/CallGraph.cpp @@ -33,7 +33,11 @@ namespace solidity void CallGraph::addNode(ASTNode const& _node) { - _node.accept(*this); + if (!m_nodesSeen.count(&_node)) + { + m_workQueue.push(&_node); + m_nodesSeen.insert(&_node); + } } set<FunctionDefinition const*> const& CallGraph::getCalls() @@ -53,20 +57,26 @@ void CallGraph::computeCallGraph() bool CallGraph::visit(Identifier const& _identifier) { - FunctionDefinition const* fun = dynamic_cast<FunctionDefinition const*>(_identifier.getReferencedDeclaration()); - if (fun) + if (auto fun = dynamic_cast<FunctionDefinition const*>(_identifier.getReferencedDeclaration())) { - if (m_overrideResolver) - fun = (*m_overrideResolver)(fun->getName()); + if (m_functionOverrideResolver) + fun = (*m_functionOverrideResolver)(fun->getName()); solAssert(fun, "Error finding override for function " + fun->getName()); - addFunction(*fun); + addNode(*fun); + } + if (auto modifier = dynamic_cast<ModifierDefinition const*>(_identifier.getReferencedDeclaration())) + { + if (m_modifierOverrideResolver) + modifier = (*m_modifierOverrideResolver)(modifier->getName()); + solAssert(modifier, "Error finding override for modifier " + modifier->getName()); + addNode(*modifier); } return true; } bool CallGraph::visit(FunctionDefinition const& _function) { - addFunction(_function); + m_functionsSeen.insert(&_function); return true; } @@ -83,7 +93,7 @@ bool CallGraph::visit(MemberAccess const& _memberAccess) for (ASTPointer<FunctionDefinition> const& function: contract.getDefinedFunctions()) if (function->getName() == _memberAccess.getMemberName()) { - addFunction(*function); + addNode(*function); return true; } } @@ -91,14 +101,5 @@ bool CallGraph::visit(MemberAccess const& _memberAccess) return true; } -void CallGraph::addFunction(FunctionDefinition const& _function) -{ - if (!m_functionsSeen.count(&_function)) - { - m_functionsSeen.insert(&_function); - m_workQueue.push(&_function); - } -} - } } diff --git a/CallGraph.h b/CallGraph.h index 90176e7e..9af5cdf9 100644 --- a/CallGraph.h +++ b/CallGraph.h @@ -39,9 +39,13 @@ namespace solidity class CallGraph: private ASTConstVisitor { public: - using OverrideResolver = std::function<FunctionDefinition const*(std::string const&)>; + using FunctionOverrideResolver = std::function<FunctionDefinition const*(std::string const&)>; + using ModifierOverrideResolver = std::function<ModifierDefinition const*(std::string const&)>; - CallGraph(OverrideResolver const& _overrideResolver): m_overrideResolver(&_overrideResolver) {} + CallGraph(FunctionOverrideResolver const& _functionOverrideResolver, + ModifierOverrideResolver const& _modifierOverrideResolver): + m_functionOverrideResolver(&_functionOverrideResolver), + m_modifierOverrideResolver(&_modifierOverrideResolver) {} void addNode(ASTNode const& _node); @@ -53,11 +57,12 @@ private: virtual bool visit(MemberAccess const& _memberAccess) override; void computeCallGraph(); - void addFunction(FunctionDefinition const& _function); - OverrideResolver const* m_overrideResolver; + FunctionOverrideResolver const* m_functionOverrideResolver; + ModifierOverrideResolver const* m_modifierOverrideResolver; + std::set<ASTNode const*> m_nodesSeen; std::set<FunctionDefinition const*> m_functionsSeen; - std::queue<FunctionDefinition const*> m_workQueue; + std::queue<ASTNode const*> m_workQueue; }; } diff --git a/Compiler.cpp b/Compiler.cpp index fa8eb775..99a429bc 100644 --- a/Compiler.cpp +++ b/Compiler.cpp @@ -42,9 +42,13 @@ void Compiler::compileContract(ContractDefinition const& _contract, initializeContext(_contract, _contracts); for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts()) + { for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) if (!function->isConstructor()) m_context.addFunction(*function); + for (ASTPointer<ModifierDefinition> const& modifier: contract->getFunctionModifiers()) + m_context.addModifier(*modifier); + } appendFunctionSelector(_contract); for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts()) @@ -67,6 +71,13 @@ void Compiler::initializeContext(ContractDefinition const& _contract, void Compiler::packIntoContractCreator(ContractDefinition const& _contract, CompilerContext const& _runtimeContext) { + std::vector<ContractDefinition const*> const& bases = _contract.getLinearizedBaseContracts(); + + // Make all modifiers known to the context. + for (ContractDefinition const* contract: bases) + for (ASTPointer<ModifierDefinition> const& modifier: contract->getFunctionModifiers()) + m_context.addModifier(*modifier); + // arguments for base constructors, filled in derived-to-base order map<ContractDefinition const*, vector<ASTPointer<Expression>> const*> baseArguments; set<FunctionDefinition const*> neededFunctions; @@ -74,10 +85,8 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp // Determine the arguments that are used for the base constructors and also which functions // are needed at compile time. - std::vector<ContractDefinition const*> const& bases = _contract.getLinearizedBaseContracts(); for (ContractDefinition const* contract: bases) { - //TODO include modifiers if (FunctionDefinition const* constructor = contract->getConstructor()) nodesUsedInConstructors.insert(constructor); for (ASTPointer<InheritanceSpecifier> const& base: contract->getBaseContracts()) @@ -94,7 +103,7 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp } } - auto overrideResolver = [&](string const& _name) -> FunctionDefinition const* + auto functionOverrideResolver = [&](string const& _name) -> FunctionDefinition const* { for (ContractDefinition const* contract: bases) for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) @@ -102,21 +111,26 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp return function.get(); return nullptr; }; + auto modifierOverrideResolver = [&](string const& _name) -> ModifierDefinition const* + { + return &m_context.getFunctionModifier(_name); + }; - neededFunctions = getFunctionsCalled(nodesUsedInConstructors, overrideResolver); + neededFunctions = getFunctionsCalled(nodesUsedInConstructors, functionOverrideResolver, + modifierOverrideResolver); // First add all overrides (or the functions themselves if there is no override) for (FunctionDefinition const* fun: neededFunctions) { FunctionDefinition const* override = nullptr; if (!fun->isConstructor()) - override = overrideResolver(fun->getName()); + override = functionOverrideResolver(fun->getName()); if (!!override && neededFunctions.count(override)) m_context.addFunction(*override); } // now add the rest for (FunctionDefinition const* fun: neededFunctions) - if (fun->isConstructor() || overrideResolver(fun->getName()) != fun) + if (fun->isConstructor() || functionOverrideResolver(fun->getName()) != fun) m_context.addFunction(*fun); // Call constructors in base-to-derived order. @@ -176,9 +190,10 @@ void Compiler::appendConstructorCall(FunctionDefinition const& _constructor) } set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes, - function<FunctionDefinition const*(string const&)> const& _resolveOverrides) + function<FunctionDefinition const*(string const&)> const& _resolveFunctionOverrides, + function<ModifierDefinition const*(string const&)> const& _resolveModifierOverrides) { - CallGraph callgraph(_resolveOverrides); + CallGraph callgraph(_resolveFunctionOverrides, _resolveModifierOverrides); for (ASTNode const* node: _nodes) callgraph.addNode(*node); return callgraph.getCalls(); @@ -471,25 +486,22 @@ void Compiler::appendModifierOrFunctionCode() { ASTPointer<ModifierInvocation> const& modifierInvocation = m_currentFunction->getModifiers()[m_modifierDepth]; - // TODO get the most derived override of the modifier - ModifierDefinition const* modifier = dynamic_cast<ModifierDefinition const*>( - modifierInvocation->getName()->getReferencedDeclaration()); - solAssert(!!modifier, "Modifier not found."); - solAssert(modifier->getParameters().size() == modifierInvocation->getArguments().size(), ""); - for (unsigned i = 0; i < modifier->getParameters().size(); ++i) + ModifierDefinition const& modifier = m_context.getFunctionModifier(modifierInvocation->getName()->getName()); + solAssert(modifier.getParameters().size() == modifierInvocation->getArguments().size(), ""); + for (unsigned i = 0; i < modifier.getParameters().size(); ++i) { - m_context.addVariable(*modifier->getParameters()[i]); + m_context.addVariable(*modifier.getParameters()[i]); compileExpression(*modifierInvocation->getArguments()[i], - modifier->getParameters()[i]->getType()); + modifier.getParameters()[i]->getType()); } - for (VariableDeclaration const* localVariable: modifier->getLocalVariables()) + for (VariableDeclaration const* localVariable: modifier.getLocalVariables()) m_context.addAndInitializeVariable(*localVariable); - unsigned const c_stackSurplus = CompilerUtils::getSizeOnStack(modifier->getParameters()) + - CompilerUtils::getSizeOnStack(modifier->getLocalVariables()); + unsigned const c_stackSurplus = CompilerUtils::getSizeOnStack(modifier.getParameters()) + + CompilerUtils::getSizeOnStack(modifier.getLocalVariables()); m_stackCleanupForReturn += c_stackSurplus; - modifier->getBody().accept(*this); + modifier.getBody().accept(*this); for (unsigned i = 0; i < c_stackSurplus; ++i) m_context << eth::Instruction::POP; @@ -53,7 +53,8 @@ private: /// Recursively searches the call graph and returns all functions referenced inside _nodes. /// _resolveOverride is called to resolve virtual function overrides. std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes, - std::function<FunctionDefinition const*(std::string const&)> const& _resolveOverride); + std::function<FunctionDefinition const*(std::string const&)> const& _resolveFunctionOverride, + std::function<ModifierDefinition const*(std::string const&)> const& _resolveModifierOverride); void appendFunctionSelector(ContractDefinition const& _contract); /// Creates code that unpacks the arguments for the given function, from memory if /// @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes. diff --git a/CompilerContext.cpp b/CompilerContext.cpp index 96d1840e..ad1877ba 100644 --- a/CompilerContext.cpp +++ b/CompilerContext.cpp @@ -66,6 +66,11 @@ void CompilerContext::addFunction(FunctionDefinition const& _function) m_virtualFunctionEntryLabels.insert(make_pair(_function.getName(), tag)); } +void CompilerContext::addModifier(ModifierDefinition const& _modifier) +{ + m_functionModifiers.insert(make_pair(_modifier.getName(), &_modifier)); +} + bytes const& CompilerContext::getCompiledContract(const ContractDefinition& _contract) const { auto ret = m_compiledContracts.find(&_contract); @@ -92,6 +97,13 @@ eth::AssemblyItem CompilerContext::getVirtualFunctionEntryLabel(FunctionDefiniti return res->second.tag(); } +ModifierDefinition const& CompilerContext::getFunctionModifier(string const& _name) const +{ + auto res = m_functionModifiers.find(_name); + solAssert(res != m_functionModifiers.end(), "Function modifier override not found."); + return *res->second; +} + unsigned CompilerContext::getBaseStackOffsetOfVariable(Declaration const& _declaration) const { auto res = m_localVariables.find(&_declaration); diff --git a/CompilerContext.h b/CompilerContext.h index 39d8c6f6..d82dfe51 100644 --- a/CompilerContext.h +++ b/CompilerContext.h @@ -45,6 +45,8 @@ public: void addVariable(VariableDeclaration const& _declaration, unsigned _offsetToCurrent = 0); void addAndInitializeVariable(VariableDeclaration const& _declaration); void addFunction(FunctionDefinition const& _function); + /// Adds the given modifier to the list by name if the name is not present already. + void addModifier(ModifierDefinition const& _modifier); void setCompiledContracts(std::map<ContractDefinition const*, bytes const*> const& _contracts) { m_compiledContracts = _contracts; } bytes const& getCompiledContract(ContractDefinition const& _contract) const; @@ -59,6 +61,7 @@ public: eth::AssemblyItem getFunctionEntryLabel(FunctionDefinition const& _function) const; /// @returns the entry label of the given function and takes overrides into account. eth::AssemblyItem getVirtualFunctionEntryLabel(FunctionDefinition const& _function) const; + ModifierDefinition const& getFunctionModifier(std::string const& _name) const; /// Returns the distance of the given local variable from the bottom of the stack (of the current function). unsigned getBaseStackOffsetOfVariable(Declaration const& _declaration) const; /// If supplied by a value returned by @ref getBaseStackOffsetOfVariable(variable), returns @@ -118,6 +121,8 @@ private: std::map<Declaration const*, eth::AssemblyItem> m_functionEntryLabels; /// Labels pointing to the entry points of function overrides. std::map<std::string, eth::AssemblyItem> m_virtualFunctionEntryLabels; + /// Mapping to obtain function modifiers by name. Should be filled from derived to base. + std::map<std::string, ModifierDefinition const*> m_functionModifiers; }; } |