diff options
-rw-r--r-- | CallGraph.cpp | 105 | ||||
-rw-r--r-- | CallGraph.h | 69 | ||||
-rw-r--r-- | Compiler.cpp | 126 | ||||
-rw-r--r-- | Compiler.h | 8 | ||||
-rw-r--r-- | CompilerContext.cpp | 77 | ||||
-rw-r--r-- | CompilerContext.h | 27 | ||||
-rw-r--r-- | CompilerStack.cpp | 1 | ||||
-rw-r--r-- | ExpressionCompiler.cpp | 24 | ||||
-rw-r--r-- | GlobalContext.cpp | 9 | ||||
-rw-r--r-- | GlobalContext.h | 2 | ||||
-rw-r--r-- | Types.cpp | 20 | ||||
-rw-r--r-- | Types.h | 13 |
12 files changed, 153 insertions, 328 deletions
diff --git a/CallGraph.cpp b/CallGraph.cpp deleted file mode 100644 index 5f8fc547..00000000 --- a/CallGraph.cpp +++ /dev/null @@ -1,105 +0,0 @@ - -/* - This file is part of cpp-ethereum. - - cpp-ethereum is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - cpp-ethereum is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with cpp-ethereum. If not, see <http://www.gnu.org/licenses/>. -*/ -/** - * @author Christian <c@ethdev.com> - * @date 2014 - * Callgraph of functions inside a contract. - */ - -#include <libsolidity/AST.h> -#include <libsolidity/CallGraph.h> - -using namespace std; - -namespace dev -{ -namespace solidity -{ - -void CallGraph::addNode(ASTNode const& _node) -{ - if (!m_nodesSeen.count(&_node)) - { - m_workQueue.push(&_node); - m_nodesSeen.insert(&_node); - } -} - -set<FunctionDefinition const*> const& CallGraph::getCalls() -{ - computeCallGraph(); - return m_functionsSeen; -} - -void CallGraph::computeCallGraph() -{ - while (!m_workQueue.empty()) - { - m_workQueue.front()->accept(*this); - m_workQueue.pop(); - } -} - -bool CallGraph::visit(Identifier const& _identifier) -{ - if (auto fun = dynamic_cast<FunctionDefinition const*>(_identifier.getReferencedDeclaration())) - { - if (m_functionOverrideResolver) - fun = (*m_functionOverrideResolver)(fun->getName()); - solAssert(fun, "Error finding override for function " + fun->getName()); - addNode(*fun); - } - if (auto modifier = dynamic_cast<ModifierDefinition const*>(_identifier.getReferencedDeclaration())) - { - if (m_modifierOverrideResolver) - modifier = (*m_modifierOverrideResolver)(modifier->getName()); - solAssert(modifier, "Error finding override for modifier " + modifier->getName()); - addNode(*modifier); - } - return true; -} - -bool CallGraph::visit(FunctionDefinition const& _function) -{ - m_functionsSeen.insert(&_function); - return true; -} - -bool CallGraph::visit(MemberAccess const& _memberAccess) -{ - // used for "BaseContract.baseContractFunction" - if (_memberAccess.getExpression().getType()->getCategory() == Type::Category::TYPE) - { - TypeType const& type = dynamic_cast<TypeType const&>(*_memberAccess.getExpression().getType()); - if (type.getMembers().getMemberType(_memberAccess.getMemberName())) - { - ContractDefinition const& contract = dynamic_cast<ContractType const&>(*type.getActualType()) - .getContractDefinition(); - for (ASTPointer<FunctionDefinition> const& function: contract.getDefinedFunctions()) - if (function->getName() == _memberAccess.getMemberName()) - { - addNode(*function); - return true; - } - } - } - return true; -} - -} -} diff --git a/CallGraph.h b/CallGraph.h deleted file mode 100644 index 9af5cdf9..00000000 --- a/CallGraph.h +++ /dev/null @@ -1,69 +0,0 @@ -/* - This file is part of cpp-ethereum. - - cpp-ethereum is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - cpp-ethereum is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with cpp-ethereum. If not, see <http://www.gnu.org/licenses/>. -*/ -/** - * @author Christian <c@ethdev.com> - * @date 2014 - * Callgraph of functions inside a contract. - */ - -#include <set> -#include <queue> -#include <functional> -#include <boost/range/iterator_range.hpp> -#include <libsolidity/ASTVisitor.h> - -namespace dev -{ -namespace solidity -{ - -/** - * Can be used to compute the graph of calls (or rather references) between functions of the same - * contract. Current functionality is limited to computing all functions that are directly - * or indirectly called by some functions. - */ -class CallGraph: private ASTConstVisitor -{ -public: - using FunctionOverrideResolver = std::function<FunctionDefinition const*(std::string const&)>; - using ModifierOverrideResolver = std::function<ModifierDefinition const*(std::string const&)>; - - CallGraph(FunctionOverrideResolver const& _functionOverrideResolver, - ModifierOverrideResolver const& _modifierOverrideResolver): - m_functionOverrideResolver(&_functionOverrideResolver), - m_modifierOverrideResolver(&_modifierOverrideResolver) {} - - void addNode(ASTNode const& _node); - - std::set<FunctionDefinition const*> const& getCalls(); - -private: - virtual bool visit(FunctionDefinition const& _function) override; - virtual bool visit(Identifier const& _identifier) override; - virtual bool visit(MemberAccess const& _memberAccess) override; - - void computeCallGraph(); - - FunctionOverrideResolver const* m_functionOverrideResolver; - ModifierOverrideResolver const* m_modifierOverrideResolver; - std::set<ASTNode const*> m_nodesSeen; - std::set<FunctionDefinition const*> m_functionsSeen; - std::queue<ASTNode const*> m_workQueue; -}; - -} -} diff --git a/Compiler.cpp b/Compiler.cpp index c7656363..93784adf 100644 --- a/Compiler.cpp +++ b/Compiler.cpp @@ -28,7 +28,6 @@ #include <libsolidity/Compiler.h> #include <libsolidity/ExpressionCompiler.h> #include <libsolidity/CompilerUtils.h> -#include <libsolidity/CallGraph.h> using namespace std; @@ -40,31 +39,13 @@ void Compiler::compileContract(ContractDefinition const& _contract, { m_context = CompilerContext(); // clear it just in case initializeContext(_contract, _contracts); - - for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts()) - { - for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) - if (!function->isConstructor()) - m_context.addFunction(*function); - - for (ASTPointer<VariableDeclaration> const& vardecl: contract->getStateVariables()) - if (vardecl->isPublic()) - m_context.addFunction(*vardecl); - - for (ASTPointer<ModifierDefinition> const& modifier: contract->getFunctionModifiers()) - m_context.addModifier(*modifier); - } - appendFunctionSelector(_contract); - for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts()) + set<Declaration const*> functions = m_context.getFunctionsWithoutCode(); + while (!functions.empty()) { - for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) - if (!function->isConstructor()) - function->accept(*this); - - for (ASTPointer<VariableDeclaration> const& vardecl: contract->getStateVariables()) - if (vardecl->isPublic()) - generateAccessorCode(*vardecl); + for (Declaration const* function: functions) + function->accept(*this); + functions = m_context.getFunctionsWithoutCode(); } // Swap the runtime context with the creation-time context @@ -77,72 +58,26 @@ void Compiler::initializeContext(ContractDefinition const& _contract, map<ContractDefinition const*, bytes const*> const& _contracts) { m_context.setCompiledContracts(_contracts); + m_context.setInheritanceHierarchy(_contract.getLinearizedBaseContracts()); registerStateVariables(_contract); } void Compiler::packIntoContractCreator(ContractDefinition const& _contract, CompilerContext const& _runtimeContext) { - std::vector<ContractDefinition const*> const& bases = _contract.getLinearizedBaseContracts(); - - // Make all modifiers known to the context. - for (ContractDefinition const* contract: bases) - for (ASTPointer<ModifierDefinition> const& modifier: contract->getFunctionModifiers()) - m_context.addModifier(*modifier); - // arguments for base constructors, filled in derived-to-base order map<ContractDefinition const*, vector<ASTPointer<Expression>> const*> baseArguments; - set<FunctionDefinition const*> neededFunctions; - set<ASTNode const*> nodesUsedInConstructors; - // Determine the arguments that are used for the base constructors and also which functions - // are needed at compile time. + // Determine the arguments that are used for the base constructors. + std::vector<ContractDefinition const*> const& bases = _contract.getLinearizedBaseContracts(); for (ContractDefinition const* contract: bases) - { - if (FunctionDefinition const* constructor = contract->getConstructor()) - nodesUsedInConstructors.insert(constructor); for (ASTPointer<InheritanceSpecifier> const& base: contract->getBaseContracts()) { ContractDefinition const* baseContract = dynamic_cast<ContractDefinition const*>( base->getName()->getReferencedDeclaration()); solAssert(baseContract, ""); if (baseArguments.count(baseContract) == 0) - { baseArguments[baseContract] = &base->getArguments(); - for (ASTPointer<Expression> const& arg: base->getArguments()) - nodesUsedInConstructors.insert(arg.get()); - } } - } - - auto functionOverrideResolver = [&](string const& _name) -> FunctionDefinition const* - { - for (ContractDefinition const* contract: bases) - for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) - if (!function->isConstructor() && function->getName() == _name) - return function.get(); - return nullptr; - }; - auto modifierOverrideResolver = [&](string const& _name) -> ModifierDefinition const* - { - return &m_context.getFunctionModifier(_name); - }; - - neededFunctions = getFunctionsCalled(nodesUsedInConstructors, functionOverrideResolver, - modifierOverrideResolver); - - // First add all overrides (or the functions themselves if there is no override) - for (FunctionDefinition const* fun: neededFunctions) - { - FunctionDefinition const* override = nullptr; - if (!fun->isConstructor()) - override = functionOverrideResolver(fun->getName()); - if (!!override && neededFunctions.count(override)) - m_context.addFunction(*override); - } - // now add the rest - for (FunctionDefinition const* fun: neededFunctions) - if (fun->isConstructor() || functionOverrideResolver(fun->getName()) != fun) - m_context.addFunction(*fun); // Call constructors in base-to-derived order. // The Constructor for the most derived contract is called later. @@ -164,10 +99,14 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp m_context << eth::Instruction::DUP1 << sub << u256(0) << eth::Instruction::CODECOPY; m_context << u256(0) << eth::Instruction::RETURN; - // note that we have to explicitly include all used functions because of absolute jump - // labels - for (FunctionDefinition const* fun: neededFunctions) - fun->accept(*this); + // note that we have to include the functions again because of absolute jump labels + set<Declaration const*> functions = m_context.getFunctionsWithoutCode(); + while (!functions.empty()) + { + for (Declaration const* function: functions) + function->accept(*this); + functions = m_context.getFunctionsWithoutCode(); + } } void Compiler::appendBaseConstructorCall(FunctionDefinition const& _constructor, @@ -201,16 +140,6 @@ void Compiler::appendConstructorCall(FunctionDefinition const& _constructor) m_context << returnTag; } -set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes, - function<FunctionDefinition const*(string const&)> const& _resolveFunctionOverrides, - function<ModifierDefinition const*(string const&)> const& _resolveModifierOverrides) -{ - CallGraph callgraph(_resolveFunctionOverrides, _resolveModifierOverrides); - for (ASTNode const* node: _nodes) - callgraph.addNode(*node); - return callgraph.getCalls(); -} - void Compiler::appendFunctionSelector(ContractDefinition const& _contract) { map<FixedHash<4>, FunctionDescription> interfaceFunctions = _contract.getInterfaceFunctions(); @@ -292,19 +221,22 @@ void Compiler::registerStateVariables(ContractDefinition const& _contract) m_context.addStateVariable(*variable); } -void Compiler::generateAccessorCode(VariableDeclaration const& _varDecl) +bool Compiler::visit(VariableDeclaration const& _variableDeclaration) { - m_context.startNewFunction(); + solAssert(_variableDeclaration.isStateVariable(), "Compiler visit to non-state variable declaration."); + + m_context.startFunction(_variableDeclaration); m_breakTags.clear(); m_continueTags.clear(); - m_context << m_context.getFunctionEntryLabel(_varDecl); - ExpressionCompiler::appendStateVariableAccessor(m_context, _varDecl); + m_context << m_context.getFunctionEntryLabel(_variableDeclaration); + ExpressionCompiler::appendStateVariableAccessor(m_context, _variableDeclaration); - unsigned sizeOnStack = _varDecl.getType()->getSizeOnStack(); - solAssert(sizeOnStack <= 15, "Illegal variable stack size detected"); - m_context << eth::dupInstruction(sizeOnStack + 1); - m_context << eth::Instruction::JUMP; + unsigned sizeOnStack = _variableDeclaration.getType()->getSizeOnStack(); + solAssert(sizeOnStack <= 15, "Stack too deep."); + m_context << eth::dupInstruction(sizeOnStack + 1) << eth::Instruction::JUMP; + + return false; } bool Compiler::visit(FunctionDefinition const& _function) @@ -313,7 +245,7 @@ bool Compiler::visit(FunctionDefinition const& _function) // caller puts: [retarg0] ... [retargm] [return address] [arg0] ... [argn] // although note that this reduces the size of the visible stack - m_context.startNewFunction(); + m_context.startFunction(_function); m_returnTag = m_context.newTag(); m_breakTags.clear(); m_continueTags.clear(); @@ -321,8 +253,6 @@ bool Compiler::visit(FunctionDefinition const& _function) m_currentFunction = &_function; m_modifierDepth = 0; - m_context << m_context.getFunctionEntryLabel(_function); - // stack upon entry: [return address] [arg0] [arg1] ... [argn] // reserve additional slots: [retarg0] ... [retargm] [localvar0] ... [localvarp] @@ -50,11 +50,6 @@ private: void appendBaseConstructorCall(FunctionDefinition const& _constructor, std::vector<ASTPointer<Expression>> const& _arguments); void appendConstructorCall(FunctionDefinition const& _constructor); - /// Recursively searches the call graph and returns all functions referenced inside _nodes. - /// _resolveOverride is called to resolve virtual function overrides. - std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes, - std::function<FunctionDefinition const*(std::string const&)> const& _resolveFunctionOverride, - std::function<ModifierDefinition const*(std::string const&)> const& _resolveModifierOverride); void appendFunctionSelector(ContractDefinition const& _contract); /// Creates code that unpacks the arguments for the given function represented by a vector of TypePointers. /// From memory if @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes. @@ -63,8 +58,7 @@ private: void registerStateVariables(ContractDefinition const& _contract); - void generateAccessorCode(VariableDeclaration const& _varDecl); - + virtual bool visit(VariableDeclaration const& _variableDeclaration) override; virtual bool visit(FunctionDefinition const& _function) override; virtual bool visit(IfStatement const& _ifStatement) override; virtual bool visit(WhileStatement const& _whileStatement) override; diff --git a/CompilerContext.cpp b/CompilerContext.cpp index ea349c0d..52910a55 100644 --- a/CompilerContext.cpp +++ b/CompilerContext.cpp @@ -43,6 +43,14 @@ void CompilerContext::addStateVariable(VariableDeclaration const& _declaration) m_stateVariablesSize += _declaration.getType()->getStorageSize(); } +void CompilerContext::startFunction(Declaration const& _function) +{ + m_functionsWithCode.insert(&_function); + m_localVariables.clear(); + m_asm.setDeposit(0); + *this << getFunctionEntryLabel(_function); +} + void CompilerContext::addVariable(VariableDeclaration const& _declaration, unsigned _offsetToCurrent) { @@ -59,18 +67,6 @@ void CompilerContext::addAndInitializeVariable(VariableDeclaration const& _decla *this << u256(0); } -void CompilerContext::addFunction(Declaration const& _decl) -{ - eth::AssemblyItem tag(m_asm.newTag()); - m_functionEntryLabels.insert(make_pair(&_decl, tag)); - m_virtualFunctionEntryLabels.insert(make_pair(_decl.getName(), tag)); -} - -void CompilerContext::addModifier(ModifierDefinition const& _modifier) -{ - m_functionModifiers.insert(make_pair(_modifier.getName(), &_modifier)); -} - bytes const& CompilerContext::getCompiledContract(const ContractDefinition& _contract) const { auto ret = m_compiledContracts.find(&_contract); @@ -83,25 +79,62 @@ bool CompilerContext::isLocalVariable(Declaration const* _declaration) const return m_localVariables.count(_declaration); } -eth::AssemblyItem CompilerContext::getFunctionEntryLabel(Declaration const& _declaration) const +eth::AssemblyItem CompilerContext::getFunctionEntryLabel(Declaration const& _declaration) { auto res = m_functionEntryLabels.find(&_declaration); - solAssert(res != m_functionEntryLabels.end(), "Function entry label not found."); - return res->second.tag(); + if (res == m_functionEntryLabels.end()) + { + eth::AssemblyItem tag(m_asm.newTag()); + m_functionEntryLabels.insert(make_pair(&_declaration, tag)); + return tag.tag(); + } + else + return res->second.tag(); +} + +eth::AssemblyItem CompilerContext::getVirtualFunctionEntryLabel(FunctionDefinition const& _function) +{ + solAssert(!m_inheritanceHierarchy.empty(), "No inheritance hierarchy set."); + for (ContractDefinition const* contract: m_inheritanceHierarchy) + for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) + if (!function->isConstructor() && function->getName() == _function.getName()) + return getFunctionEntryLabel(*function); + solAssert(false, "Virtual function " + _function.getName() + " not found."); + return m_asm.newTag(); // not reached +} + +eth::AssemblyItem CompilerContext::getSuperFunctionEntryLabel(string const& _name, ContractDefinition const& _base) +{ + // search for first contract after _base + solAssert(!m_inheritanceHierarchy.empty(), "No inheritance hierarchy set."); + auto it = find(m_inheritanceHierarchy.begin(), m_inheritanceHierarchy.end(), &_base); + solAssert(it != m_inheritanceHierarchy.end(), "Base not found in inheritance hierarchy."); + for (++it; it != m_inheritanceHierarchy.end(); ++it) + for (ASTPointer<FunctionDefinition> const& function: (*it)->getDefinedFunctions()) + if (!function->isConstructor() && function->getName() == _name) + return getFunctionEntryLabel(*function); + solAssert(false, "Super function " + _name + " not found."); + return m_asm.newTag(); // not reached } -eth::AssemblyItem CompilerContext::getVirtualFunctionEntryLabel(FunctionDefinition const& _function) const +set<Declaration const*> CompilerContext::getFunctionsWithoutCode() { - auto res = m_virtualFunctionEntryLabels.find(_function.getName()); - solAssert(res != m_virtualFunctionEntryLabels.end(), "Function entry label not found."); - return res->second.tag(); + set<Declaration const*> functions; + for (auto const& it: m_functionEntryLabels) + if (m_functionsWithCode.count(it.first) == 0) + functions.insert(it.first); + return move(functions); } ModifierDefinition const& CompilerContext::getFunctionModifier(string const& _name) const { - auto res = m_functionModifiers.find(_name); - solAssert(res != m_functionModifiers.end(), "Function modifier override not found."); - return *res->second; + solAssert(!m_inheritanceHierarchy.empty(), "No inheritance hierarchy set."); + for (ContractDefinition const* contract: m_inheritanceHierarchy) + for (ASTPointer<ModifierDefinition> const& modifier: contract->getFunctionModifiers()) + if (modifier->getName() == _name) + return *modifier.get(); + BOOST_THROW_EXCEPTION(InternalCompilerError() + << errinfo_comment("Function modifier " + _name + " not found.")); } unsigned CompilerContext::getBaseStackOffsetOfVariable(Declaration const& _declaration) const diff --git a/CompilerContext.h b/CompilerContext.h index 9de3385a..6d6a65b6 100644 --- a/CompilerContext.h +++ b/CompilerContext.h @@ -41,12 +41,8 @@ class CompilerContext public: void addMagicGlobal(MagicVariableDeclaration const& _declaration); void addStateVariable(VariableDeclaration const& _declaration); - void startNewFunction() { m_localVariables.clear(); m_asm.setDeposit(0); } void addVariable(VariableDeclaration const& _declaration, unsigned _offsetToCurrent = 0); void addAndInitializeVariable(VariableDeclaration const& _declaration); - void addFunction(Declaration const& _decl); - /// Adds the given modifier to the list by name if the name is not present already. - void addModifier(ModifierDefinition const& _modifier); void setCompiledContracts(std::map<ContractDefinition const*, bytes const*> const& _contracts) { m_compiledContracts = _contracts; } bytes const& getCompiledContract(ContractDefinition const& _contract) const; @@ -54,13 +50,22 @@ public: void adjustStackOffset(int _adjustment) { m_asm.adjustDeposit(_adjustment); } bool isMagicGlobal(Declaration const* _declaration) const { return m_magicGlobals.count(_declaration) != 0; } - bool isFunctionDefinition(Declaration const* _declaration) const { return m_functionEntryLabels.count(_declaration) != 0; } bool isLocalVariable(Declaration const* _declaration) const; bool isStateVariable(Declaration const* _declaration) const { return m_stateVariables.count(_declaration) != 0; } - eth::AssemblyItem getFunctionEntryLabel(Declaration const& _declaration) const; + eth::AssemblyItem getFunctionEntryLabel(Declaration const& _declaration); + void setInheritanceHierarchy(std::vector<ContractDefinition const*> const& _hierarchy) { m_inheritanceHierarchy = _hierarchy; } /// @returns the entry label of the given function and takes overrides into account. - eth::AssemblyItem getVirtualFunctionEntryLabel(FunctionDefinition const& _function) const; + eth::AssemblyItem getVirtualFunctionEntryLabel(FunctionDefinition const& _function); + /// @returns the entry label of function with the given name from the most derived class just + /// above _base in the current inheritance hierarchy. + eth::AssemblyItem getSuperFunctionEntryLabel(std::string const& _name, ContractDefinition const& _base); + /// @returns the set of functions for which we still need to generate code + std::set<Declaration const*> getFunctionsWithoutCode(); + /// Resets function specific members, inserts the function entry label and marks the function + /// as "having code". + void startFunction(Declaration const& _function); + ModifierDefinition const& getFunctionModifier(std::string const& _name) const; /// Returns the distance of the given local variable from the bottom of the stack (of the current function). unsigned getBaseStackOffsetOfVariable(Declaration const& _declaration) const; @@ -119,10 +124,10 @@ private: std::map<Declaration const*, unsigned> m_localVariables; /// Labels pointing to the entry points of functions. std::map<Declaration const*, eth::AssemblyItem> m_functionEntryLabels; - /// Labels pointing to the entry points of function overrides. - std::map<std::string, eth::AssemblyItem> m_virtualFunctionEntryLabels; - /// Mapping to obtain function modifiers by name. Should be filled from derived to base. - std::map<std::string, ModifierDefinition const*> m_functionModifiers; + /// Set of functions for which we did not yet generate code. + std::set<Declaration const*> m_functionsWithCode; + /// List of current inheritance hierarchy from derived to base. + std::vector<ContractDefinition const*> m_inheritanceHierarchy; }; } diff --git a/CompilerStack.cpp b/CompilerStack.cpp index 0b8218bb..3ed0d362 100644 --- a/CompilerStack.cpp +++ b/CompilerStack.cpp @@ -94,6 +94,7 @@ void CompilerStack::parse() { m_globalContext->setCurrentContract(*contract); resolver.updateDeclaration(*m_globalContext->getCurrentThis()); + resolver.updateDeclaration(*m_globalContext->getCurrentSuper()); resolver.resolveNamesAndTypes(*contract); m_contracts[contract->getName()].contract = contract; } diff --git a/ExpressionCompiler.cpp b/ExpressionCompiler.cpp index 15ee17fd..d2f709be 100644 --- a/ExpressionCompiler.cpp +++ b/ExpressionCompiler.cpp @@ -366,14 +366,22 @@ void ExpressionCompiler::endVisit(MemberAccess const& _memberAccess) case Type::Category::CONTRACT: { ContractType const& type = dynamic_cast<ContractType const&>(*_memberAccess.getExpression().getType()); - u256 identifier = type.getFunctionIdentifier(member); - if (identifier != Invalid256) + if (type.isSuper()) { - appendTypeConversion(type, IntegerType(0, IntegerType::Modifier::ADDRESS), true); - m_context << identifier; + m_context << m_context.getSuperFunctionEntryLabel(member, type.getContractDefinition()).pushTag(); break; } - // fall-through to "integer" otherwise (address) + else + { + u256 identifier = type.getFunctionIdentifier(member); + if (identifier != Invalid256) + { + appendTypeConversion(type, IntegerType(0, IntegerType::Modifier::ADDRESS), true); + m_context << identifier; + break; + } + // fall-through to "integer" otherwise (address) + } } case Type::Category::INTEGER: if (member == "balance") @@ -469,8 +477,10 @@ void ExpressionCompiler::endVisit(Identifier const& _identifier) Declaration const* declaration = _identifier.getReferencedDeclaration(); if (MagicVariableDeclaration const* magicVar = dynamic_cast<MagicVariableDeclaration const*>(declaration)) { - if (magicVar->getType()->getCategory() == Type::Category::CONTRACT) // must be "this" - m_context << eth::Instruction::ADDRESS; + if (magicVar->getType()->getCategory() == Type::Category::CONTRACT) + // "this" or "super" + if (!dynamic_cast<ContractType const&>(*magicVar->getType()).isSuper()) + m_context << eth::Instruction::ADDRESS; } else if (FunctionDefinition const* functionDef = dynamic_cast<FunctionDefinition const*>(declaration)) m_context << m_context.getVirtualFunctionEntryLabel(*functionDef).pushTag(); diff --git a/GlobalContext.cpp b/GlobalContext.cpp index c7eea92d..40a498c8 100644 --- a/GlobalContext.cpp +++ b/GlobalContext.cpp @@ -83,5 +83,14 @@ MagicVariableDeclaration const* GlobalContext::getCurrentThis() const } +MagicVariableDeclaration const* GlobalContext::getCurrentSuper() const +{ + if (!m_superPointer[m_currentContract]) + m_superPointer[m_currentContract] = make_shared<MagicVariableDeclaration>( + "super", make_shared<ContractType>(*m_currentContract, true)); + return m_superPointer[m_currentContract].get(); + +} + } } diff --git a/GlobalContext.h b/GlobalContext.h index dfdc6662..f861c67d 100644 --- a/GlobalContext.h +++ b/GlobalContext.h @@ -48,6 +48,7 @@ public: GlobalContext(); void setCurrentContract(ContractDefinition const& _contract); MagicVariableDeclaration const* getCurrentThis() const; + MagicVariableDeclaration const* getCurrentSuper() const; /// @returns a vector of all implicit global declarations excluding "this". std::vector<Declaration const*> getDeclarations() const; @@ -56,6 +57,7 @@ private: std::vector<std::shared_ptr<MagicVariableDeclaration const>> m_magicVariables; ContractDefinition const* m_currentContract = nullptr; std::map<ContractDefinition const*, std::shared_ptr<MagicVariableDeclaration const>> mutable m_thisPointer; + std::map<ContractDefinition const*, std::shared_ptr<MagicVariableDeclaration const>> mutable m_superPointer; }; } @@ -450,7 +450,9 @@ bool ContractType::isImplicitlyConvertibleTo(Type const& _convertTo) const if (_convertTo.getCategory() == Category::CONTRACT) { auto const& bases = getContractDefinition().getLinearizedBaseContracts(); - return find(bases.begin(), bases.end(), + if (m_super && bases.size() <= 1) + return false; + return find(m_super ? ++bases.begin() : bases.begin(), bases.end(), &dynamic_cast<ContractType const&>(_convertTo).getContractDefinition()) != bases.end(); } return false; @@ -472,12 +474,12 @@ bool ContractType::operator==(Type const& _other) const if (_other.getCategory() != getCategory()) return false; ContractType const& other = dynamic_cast<ContractType const&>(_other); - return other.m_contract == m_contract; + return other.m_contract == m_contract && other.m_super == m_super; } string ContractType::toString() const { - return "contract " + m_contract.getName(); + return "contract " + string(m_super ? "super " : "") + m_contract.getName(); } MemberList const& ContractType::getMembers() const @@ -488,8 +490,16 @@ MemberList const& ContractType::getMembers() const // All address members and all interface functions map<string, shared_ptr<Type const>> members(IntegerType::AddressMemberList.begin(), IntegerType::AddressMemberList.end()); - for (auto const& it: m_contract.getInterfaceFunctions()) - members[it.second.getName()] = it.second.getFunctionTypeShared(); + if (m_super) + { + for (ContractDefinition const* base: m_contract.getLinearizedBaseContracts()) + for (ASTPointer<FunctionDefinition> const& function: base->getDefinedFunctions()) + if (!function->isConstructor()) + members.insert(make_pair(function->getName(), make_shared<FunctionType>(*function, true))); + } + else + for (auto const& it: m_contract.getInterfaceFunctions()) + members[it.second.getName()] = it.second.getFunctionTypeShared(); m_members.reset(new MemberList(members)); } return *m_members; @@ -277,7 +277,8 @@ class ContractType: public Type { public: virtual Category getCategory() const override { return Category::CONTRACT; } - ContractType(ContractDefinition const& _contract): m_contract(_contract) {} + explicit ContractType(ContractDefinition const& _contract, bool _super = false): + m_contract(_contract), m_super(_super) {} /// Contracts can be implicitly converted to super classes and to addresses. virtual bool isImplicitlyConvertibleTo(Type const& _convertTo) const override; /// Contracts can be converted to themselves and to integers. @@ -289,6 +290,7 @@ public: virtual MemberList const& getMembers() const override; + bool isSuper() const { return m_super; } ContractDefinition const& getContractDefinition() const { return m_contract; } /// Returns the function type of the constructor. Note that the location part of the function type @@ -301,6 +303,9 @@ public: private: ContractDefinition const& m_contract; + /// If true, it is the "super" type of the current contract, i.e. it contains only inherited + /// members. + bool m_super; /// Type of the constructor, @see getConstructorType. Lazily initialized. mutable std::shared_ptr<FunctionType const> m_constructorType; /// List of member types, will be lazy-initialized because of recursive references. @@ -314,7 +319,7 @@ class StructType: public Type { public: virtual Category getCategory() const override { return Category::STRUCT; } - StructType(StructDefinition const& _struct): m_struct(_struct) {} + explicit StructType(StructDefinition const& _struct): m_struct(_struct) {} virtual TypePointer unaryOperatorResult(Token::Value _operator) const override; virtual bool operator==(Type const& _other) const override; virtual u256 getStorageSize() const override; @@ -448,7 +453,7 @@ class TypeType: public Type { public: virtual Category getCategory() const override { return Category::TYPE; } - TypeType(TypePointer const& _actualType, ContractDefinition const* _currentContract = nullptr): + explicit TypeType(TypePointer const& _actualType, ContractDefinition const* _currentContract = nullptr): m_actualType(_actualType), m_currentContract(_currentContract) {} TypePointer const& getActualType() const { return m_actualType; } @@ -502,7 +507,7 @@ public: enum class Kind { BLOCK, MSG, TX }; virtual Category getCategory() const override { return Category::MAGIC; } - MagicType(Kind _kind); + explicit MagicType(Kind _kind); virtual TypePointer binaryOperatorResult(Token::Value, TypePointer const&) const override { |