diff options
-rw-r--r-- | AST.cpp | 330 | ||||
-rw-r--r-- | AST.h | 63 | ||||
-rw-r--r-- | Compiler.cpp | 6 | ||||
-rw-r--r-- | CompilerContext.cpp | 42 | ||||
-rw-r--r-- | CompilerContext.h | 11 | ||||
-rw-r--r-- | DeclarationContainer.cpp | 29 | ||||
-rw-r--r-- | DeclarationContainer.h | 2 | ||||
-rw-r--r-- | ExpressionCompiler.cpp | 46 | ||||
-rw-r--r-- | LValue.cpp | 8 | ||||
-rw-r--r-- | NameAndTypeResolver.cpp | 80 | ||||
-rw-r--r-- | NameAndTypeResolver.h | 21 | ||||
-rw-r--r-- | Types.cpp | 117 | ||||
-rw-r--r-- | Types.h | 57 |
13 files changed, 507 insertions, 305 deletions
@@ -52,6 +52,7 @@ void ContractDefinition::checkTypeRequirements() for (ASTPointer<InheritanceSpecifier> const& baseSpecifier: getBaseContracts()) baseSpecifier->checkTypeRequirements(); + checkDuplicateFunctions(); checkIllegalOverrides(); checkAbstractFunctions(); @@ -82,20 +83,11 @@ void ContractDefinition::checkTypeRequirements() for (ASTPointer<FunctionDefinition> const& function: getDefinedFunctions()) function->checkTypeRequirements(); - // check for duplicate declaration - set<string> functions; - for (ASTPointer<FunctionDefinition> const& function: getDefinedFunctions()) - { - string signature = function->getCanonicalSignature(); - if (functions.count(signature)) - BOOST_THROW_EXCEPTION(DeclarationError() << errinfo_sourceLocation(function->getLocation()) - << errinfo_comment("Duplicate functions are not allowed.")); - functions.insert(signature); - } for (ASTPointer<VariableDeclaration> const& variable: m_stateVariables) variable->checkTypeRequirements(); + checkExternalTypeClashes(); // check for hash collisions in function signatures set<FixedHash<4>> hashes; for (auto const& it: getInterfaceFunctionList()) @@ -140,6 +132,33 @@ FunctionDefinition const* ContractDefinition::getFallbackFunction() const return nullptr; } +void ContractDefinition::checkDuplicateFunctions() const +{ + /// Checks that two functions with the same name defined in this contract have different + /// argument types and that there is at most one constructor. + map<string, vector<FunctionDefinition const*>> functions; + for (ASTPointer<FunctionDefinition> const& function: getDefinedFunctions()) + functions[function->getName()].push_back(function.get()); + if (functions[getName()].size() > 1) + BOOST_THROW_EXCEPTION( + DeclarationError() << + errinfo_sourceLocation(getLocation()) << + errinfo_comment("More than one constructor defined.") + ); + for (auto const& it: functions) + { + vector<FunctionDefinition const*> const& overloads = it.second; + for (size_t i = 0; i < overloads.size(); ++i) + for (size_t j = i + 1; j < overloads.size(); ++j) + if (FunctionType(*overloads[i]).hasEqualArgumentTypes(FunctionType(*overloads[j]))) + BOOST_THROW_EXCEPTION( + DeclarationError() << + errinfo_sourceLocation(overloads[j]->getLocation()) << + errinfo_comment("Function with same name and arguments already defined.") + ); + } +} + void ContractDefinition::checkAbstractFunctions() { map<string, bool> functions; @@ -166,8 +185,7 @@ void ContractDefinition::checkIllegalOverrides() const { // TODO unify this at a later point. for this we need to put the constness and the access specifier // into the types - map<string, FunctionDefinition const*> functions; - set<string> functionNames; + map<string, vector<FunctionDefinition const*>> functions; map<string, ModifierDefinition const*> modifiers; // We search from derived to base, so the stored item causes the error. @@ -180,14 +198,21 @@ void ContractDefinition::checkIllegalOverrides() const string const& name = function->getName(); if (modifiers.count(name)) BOOST_THROW_EXCEPTION(modifiers[name]->createTypeError("Override changes function to modifier.")); - FunctionDefinition const*& override = functions[function->getCanonicalSignature()]; - functionNames.insert(name); - if (!override) - override = function.get(); - else if (override->getVisibility() != function->getVisibility() || - override->isDeclaredConst() != function->isDeclaredConst() || - FunctionType(*override) != FunctionType(*function)) - BOOST_THROW_EXCEPTION(override->createTypeError("Override changes extended function signature.")); + FunctionType functionType(*function); + // function should not change the return type + for (FunctionDefinition const* overriding: functions[name]) + { + FunctionType overridingType(*overriding); + if (!overridingType.hasEqualArgumentTypes(functionType)) + continue; + if ( + overriding->getVisibility() != function->getVisibility() || + overriding->isDeclaredConst() != function->isDeclaredConst() || + overridingType != functionType + ) + BOOST_THROW_EXCEPTION(overriding->createTypeError("Override changes extended function signature.")); + } + functions[name].push_back(function.get()); } for (ASTPointer<ModifierDefinition> const& modifier: contract->getFunctionModifiers()) { @@ -197,12 +222,43 @@ void ContractDefinition::checkIllegalOverrides() const override = modifier.get(); else if (ModifierType(*override) != ModifierType(*modifier)) BOOST_THROW_EXCEPTION(override->createTypeError("Override changes modifier signature.")); - if (functionNames.count(name)) + if (!functions[name].empty()) BOOST_THROW_EXCEPTION(override->createTypeError("Override changes modifier to function.")); } } } +void ContractDefinition::checkExternalTypeClashes() const +{ + map<string, vector<pair<Declaration const*, shared_ptr<FunctionType>>>> externalDeclarations; + for (ContractDefinition const* contract: getLinearizedBaseContracts()) + { + for (ASTPointer<FunctionDefinition> const& f: contract->getDefinedFunctions()) + if (f->isPartOfExternalInterface()) + { + auto functionType = make_shared<FunctionType>(*f); + externalDeclarations[functionType->externalSignature(f->getName())].push_back( + make_pair(f.get(), functionType) + ); + } + for (ASTPointer<VariableDeclaration> const& v: contract->getStateVariables()) + if (v->isPartOfExternalInterface()) + { + auto functionType = make_shared<FunctionType>(*v); + externalDeclarations[functionType->externalSignature(v->getName())].push_back( + make_pair(v.get(), functionType) + ); + } + } + for (auto const& it: externalDeclarations) + for (size_t i = 0; i < it.second.size(); ++i) + for (size_t j = i + 1; j < it.second.size(); ++j) + if (!it.second[i].second->hasEqualArgumentTypes(*it.second[j].second)) + BOOST_THROW_EXCEPTION(it.second[j].first->createTypeError( + "Function overload clash during conversion to external types for arguments." + )); +} + std::vector<ASTPointer<EventDefinition>> const& ContractDefinition::getInterfaceEvents() const { if (!m_interfaceEvents) @@ -291,11 +347,11 @@ TypePointer EnumValue::getType(ContractDefinition const*) const void InheritanceSpecifier::checkTypeRequirements() { - m_baseName->checkTypeRequirements(); + m_baseName->checkTypeRequirements(nullptr); for (ASTPointer<Expression> const& argument: m_arguments) - argument->checkTypeRequirements(); + argument->checkTypeRequirements(nullptr); - ContractDefinition const* base = dynamic_cast<ContractDefinition const*>(m_baseName->getReferencedDeclaration()); + ContractDefinition const* base = dynamic_cast<ContractDefinition const*>(&m_baseName->getReferencedDeclaration()); solAssert(base, "Base contract not available."); TypePointers parameterTypes = ContractType(*base).getConstructorType()->getParameterTypes(); if (parameterTypes.size() != m_arguments.size()) @@ -409,11 +465,7 @@ void VariableDeclaration::checkTypeRequirements() else { // no type declared and no previous assignment, infer the type - Identifier* identifier = dynamic_cast<Identifier*>(m_value.get()); - if (identifier) - identifier->checkTypeRequirementsFromVariableDeclaration(); - else - m_value->checkTypeRequirements(); + m_value->checkTypeRequirements(nullptr); TypePointer type = m_value->getType(); if (type->getCategory() == Type::Category::IntegerConstant) @@ -452,11 +504,15 @@ void ModifierDefinition::checkTypeRequirements() void ModifierInvocation::checkTypeRequirements(vector<ASTPointer<InheritanceSpecifier>> const& _bases) { - m_modifierName->checkTypeRequirements(); + TypePointers argumentTypes; for (ASTPointer<Expression> const& argument: m_arguments) - argument->checkTypeRequirements(); + { + argument->checkTypeRequirements(nullptr); + argumentTypes.push_back(argument->getType()); + } + m_modifierName->checkTypeRequirements(&argumentTypes); - auto declaration = m_modifierName->getReferencedDeclaration(); + auto const* declaration = &m_modifierName->getReferencedDeclaration(); vector<ASTPointer<VariableDeclaration>> emptyParameterList; vector<ASTPointer<VariableDeclaration>> const* parameters = nullptr; if (auto modifier = dynamic_cast<ModifierDefinition const*>(declaration)) @@ -464,7 +520,7 @@ void ModifierInvocation::checkTypeRequirements(vector<ASTPointer<InheritanceSpec else // check parameters for Base constructors for (auto const& base: _bases) - if (declaration == base->getName()->getReferencedDeclaration()) + if (declaration == &base->getName()->getReferencedDeclaration()) { if (auto referencedConstructor = dynamic_cast<ContractDefinition const&>(*declaration).getConstructor()) parameters = &referencedConstructor->getParameters(); @@ -547,9 +603,9 @@ void VariableDeclarationStatement::checkTypeRequirements() m_variable->checkTypeRequirements(); } -void Assignment::checkTypeRequirements() +void Assignment::checkTypeRequirements(TypePointers const*) { - m_leftHandSide->checkTypeRequirements(); + m_leftHandSide->checkTypeRequirements(nullptr); m_leftHandSide->requireLValue(); if (m_leftHandSide->getType()->getCategory() == Type::Category::Mapping) BOOST_THROW_EXCEPTION(createTypeError("Mappings cannot be assigned to.")); @@ -559,7 +615,7 @@ void Assignment::checkTypeRequirements() else { // compound assignment - m_rightHandSide->checkTypeRequirements(); + m_rightHandSide->checkTypeRequirements(nullptr); TypePointer resultType = m_type->binaryOperatorResult(Token::AssignmentToBinaryOp(m_assigmentOperator), m_rightHandSide->getType()); if (!resultType || *resultType != *m_type) @@ -572,7 +628,7 @@ void Assignment::checkTypeRequirements() void ExpressionStatement::checkTypeRequirements() { - m_expression->checkTypeRequirements(); + m_expression->checkTypeRequirements(nullptr); if (m_expression->getType()->getCategory() == Type::Category::IntegerConstant) if (!dynamic_pointer_cast<IntegerConstantType const>(m_expression->getType())->getIntegerType()) BOOST_THROW_EXCEPTION(m_expression->createTypeError("Invalid integer constant.")); @@ -580,7 +636,7 @@ void ExpressionStatement::checkTypeRequirements() void Expression::expectType(Type const& _expectedType) { - checkTypeRequirements(); + checkTypeRequirements(nullptr); Type const& type = *getType(); if (!type.isImplicitlyConvertibleTo(_expectedType)) BOOST_THROW_EXCEPTION(createTypeError("Type " + type.toString() + @@ -595,10 +651,10 @@ void Expression::requireLValue() m_lvalueRequested = true; } -void UnaryOperation::checkTypeRequirements() +void UnaryOperation::checkTypeRequirements(TypePointers const*) { // Inc, Dec, Add, Sub, Not, BitNot, Delete - m_subExpression->checkTypeRequirements(); + m_subExpression->checkTypeRequirements(nullptr); if (m_operator == Token::Value::Inc || m_operator == Token::Value::Dec || m_operator == Token::Value::Delete) m_subExpression->requireLValue(); m_type = m_subExpression->getType()->unaryOperatorResult(m_operator); @@ -606,10 +662,10 @@ void UnaryOperation::checkTypeRequirements() BOOST_THROW_EXCEPTION(createTypeError("Unary operator not compatible with type.")); } -void BinaryOperation::checkTypeRequirements() +void BinaryOperation::checkTypeRequirements(TypePointers const*) { - m_left->checkTypeRequirements(); - m_right->checkTypeRequirements(); + m_left->checkTypeRequirements(nullptr); + m_right->checkTypeRequirements(nullptr); m_commonType = m_left->getType()->binaryOperatorResult(m_operator, m_right->getType()); if (!m_commonType) BOOST_THROW_EXCEPTION(createTypeError("Operator " + string(Token::toString(m_operator)) + @@ -619,17 +675,22 @@ void BinaryOperation::checkTypeRequirements() m_type = Token::isCompareOp(m_operator) ? make_shared<BoolType>() : m_commonType; } -void FunctionCall::checkTypeRequirements() +void FunctionCall::checkTypeRequirements(TypePointers const*) { - // we need to check arguments' type first as their info will be used by m_express(Identifier). + bool isPositionalCall = m_names.empty(); + + // we need to check arguments' type first as they will be forwarded to + // m_expression->checkTypeRequirements + TypePointers argumentTypes; for (ASTPointer<Expression> const& argument: m_arguments) - argument->checkTypeRequirements(); + { + argument->checkTypeRequirements(nullptr); + // only store them for positional calls + if (isPositionalCall) + argumentTypes.push_back(argument->getType()); + } - auto identifier = dynamic_cast<Identifier*>(m_expression.get()); - if (identifier) - identifier->checkTypeRequirementsWithFunctionCall(*this); - else - m_expression->checkTypeRequirements(); + m_expression->checkTypeRequirements(isPositionalCall ? &argumentTypes : nullptr); Type const* expressionType = m_expression->getType().get(); if (isTypeConversion()) @@ -639,7 +700,7 @@ void FunctionCall::checkTypeRequirements() // number of non-mapping members if (m_arguments.size() != 1) BOOST_THROW_EXCEPTION(createTypeError("More than one argument for explicit type conversion.")); - if (!m_names.empty()) + if (!isPositionalCall) BOOST_THROW_EXCEPTION(createTypeError("Type conversion cannot allow named arguments.")); if (!m_arguments.front()->getType()->isExplicitlyConvertibleTo(*type.getActualType())) BOOST_THROW_EXCEPTION(createTypeError("Explicit type conversion not allowed.")); @@ -654,8 +715,9 @@ void FunctionCall::checkTypeRequirements() if (!functionType->takesArbitraryParameters() && parameterTypes.size() != m_arguments.size()) BOOST_THROW_EXCEPTION(createTypeError("Wrong argument count for function call.")); - if (m_names.empty()) + if (isPositionalCall) { + // call by positional arguments for (size_t i = 0; i < m_arguments.size(); ++i) if (!functionType->takesArbitraryParameters() && !m_arguments[i]->getType()->isImplicitlyConvertibleTo(*parameterTypes[i])) @@ -663,6 +725,7 @@ void FunctionCall::checkTypeRequirements() } else { + // call by named arguments if (functionType->takesArbitraryParameters()) BOOST_THROW_EXCEPTION(createTypeError("Named arguments cannnot be used for functions " "that take arbitrary parameters.")); @@ -700,27 +763,6 @@ void FunctionCall::checkTypeRequirements() else m_type = functionType->getReturnParameterTypes().front(); } - else if (OverloadedFunctionType const* overloadedTypes = dynamic_cast<OverloadedFunctionType const*>(expressionType)) - { - // this only applies to "x(3)" where x is assigned by "var x = f;" where f is an overloaded functions. - auto identifier = dynamic_cast<Identifier*>(m_expression.get()); - solAssert(identifier, "only applies to 'var x = f;'"); - - Declaration const* function = overloadedTypes->getIdentifier()->overloadResolution(*this); - if (!function) - BOOST_THROW_EXCEPTION(createTypeError("Can't resolve declarations")); - - identifier->setReferencedDeclaration(*function); - identifier->checkTypeRequirements(); - - TypePointer type = identifier->getType(); - FunctionType const* functionType = dynamic_cast<FunctionType const*>(type.get()); - - if (functionType->getReturnParameterTypes().empty()) - m_type = make_shared<VoidType>(); - else - m_type = functionType->getReturnParameterTypes().front(); - } else BOOST_THROW_EXCEPTION(createTypeError("Type is not callable.")); } @@ -730,10 +772,10 @@ bool FunctionCall::isTypeConversion() const return m_expression->getType()->getCategory() == Type::Category::TypeType; } -void NewExpression::checkTypeRequirements() +void NewExpression::checkTypeRequirements(TypePointers const*) { - m_contractName->checkTypeRequirements(); - m_contract = dynamic_cast<ContractDefinition const*>(m_contractName->getReferencedDeclaration()); + m_contractName->checkTypeRequirements(nullptr); + m_contract = dynamic_cast<ContractDefinition const*>(&m_contractName->getReferencedDeclaration()); if (!m_contract) BOOST_THROW_EXCEPTION(createTypeError("Identifier is not a contract.")); if (!m_contract->isFullyImplemented()) @@ -744,15 +786,37 @@ void NewExpression::checkTypeRequirements() FunctionType::Location::Creation); } -void MemberAccess::checkTypeRequirements() +void MemberAccess::checkTypeRequirements(TypePointers const* _argumentTypes) { - m_expression->checkTypeRequirements(); + m_expression->checkTypeRequirements(nullptr); Type const& type = *m_expression->getType(); - m_type = type.getMemberType(*m_memberName); - if (!m_type) - BOOST_THROW_EXCEPTION(createTypeError("Member \"" + *m_memberName + "\" not found or not " - "visible in " + type.toString())); - // This should probably move somewhere else. + + MemberList::MemberMap possibleMembers = type.getMembers().membersByName(*m_memberName); + if (possibleMembers.size() > 1 && _argumentTypes) + { + // do override resolution + for (auto it = possibleMembers.begin(); it != possibleMembers.end();) + if ( + it->type->getCategory() == Type::Category::Function && + !dynamic_cast<FunctionType const&>(*it->type).canTakeArguments(*_argumentTypes) + ) + it = possibleMembers.erase(it); + else + ++it; + } + if (possibleMembers.size() == 0) + BOOST_THROW_EXCEPTION(createTypeError( + "Member \"" + *m_memberName + "\" not found or not visible " + "after argument-dependent lookup in " + type.toString() + )); + else if (possibleMembers.size() > 1) + BOOST_THROW_EXCEPTION(createTypeError( + "Member \"" + *m_memberName + "\" not unique " + "after argument-dependent lookup in " + type.toString() + )); + + m_referencedDeclaration = possibleMembers.front().declaration; + m_type = possibleMembers.front().type; if (type.getCategory() == Type::Category::Struct) m_isLValue = true; else if (type.getCategory() == Type::Category::Array) @@ -765,9 +829,9 @@ void MemberAccess::checkTypeRequirements() m_isLValue = false; } -void IndexAccess::checkTypeRequirements() +void IndexAccess::checkTypeRequirements(TypePointers const*) { - m_base->checkTypeRequirements(); + m_base->checkTypeRequirements(nullptr); switch (m_base->getType()->getCategory()) { case Type::Category::Array: @@ -800,7 +864,7 @@ void IndexAccess::checkTypeRequirements() m_type = make_shared<TypeType>(make_shared<ArrayType>(ArrayType::Location::Memory, type.getActualType())); else { - m_index->checkTypeRequirements(); + m_index->checkTypeRequirements(nullptr); auto length = dynamic_cast<IntegerConstantType const*>(m_index->getType().get()); if (!length) BOOST_THROW_EXCEPTION(m_index->createTypeError("Integer constant expected.")); @@ -815,89 +879,57 @@ void IndexAccess::checkTypeRequirements() } } -void Identifier::checkTypeRequirementsWithFunctionCall(FunctionCall const& _functionCall) -{ - solAssert(m_referencedDeclaration || !m_overloadedDeclarations.empty(), "Identifier not resolved."); - - if (!m_referencedDeclaration) - setReferencedDeclaration(*overloadResolution(_functionCall)); - - checkTypeRequirements(); -} - -void Identifier::checkTypeRequirementsFromVariableDeclaration() +void Identifier::checkTypeRequirements(TypePointers const* _argumentTypes) { - solAssert(m_referencedDeclaration || !m_overloadedDeclarations.empty(), "Identifier not resolved."); - if (!m_referencedDeclaration) - m_type = make_shared<OverloadedFunctionType>(this); - else - checkTypeRequirements(); - - m_isLValue = true; -} - -void Identifier::checkTypeRequirements() -{ - solAssert(m_referencedDeclaration, "Identifier not resolved."); - + { + if (!_argumentTypes) + BOOST_THROW_EXCEPTION(createTypeError("Unable to determine overloaded type.")); + overloadResolution(*_argumentTypes); + } + solAssert(!!m_referencedDeclaration, "Referenced declaration is null after overload resolution."); m_isLValue = m_referencedDeclaration->isLValue(); m_type = m_referencedDeclaration->getType(m_currentContract); if (!m_type) BOOST_THROW_EXCEPTION(createTypeError("Declaration referenced before type could be determined.")); } -Declaration const* Identifier::overloadResolution(FunctionCall const& _functionCall) +Declaration const& Identifier::getReferencedDeclaration() const +{ + solAssert(!!m_referencedDeclaration, "Identifier not resolved."); + return *m_referencedDeclaration; +} + +void Identifier::overloadResolution(TypePointers const& _argumentTypes) { - solAssert(m_overloadedDeclarations.size() > 1, "FunctionIdentifier not resolved."); solAssert(!m_referencedDeclaration, "Referenced declaration should be null before overload resolution."); + solAssert(!m_overloadedDeclarations.empty(), "No candidates for overload resolution found."); - std::vector<ASTPointer<Expression const>> arguments = _functionCall.getArguments(); - std::vector<ASTPointer<ASTString>> const& argumentNames = _functionCall.getNames(); + std::vector<Declaration const*> possibles; + if (m_overloadedDeclarations.size() == 1) + m_referencedDeclaration = *m_overloadedDeclarations.begin(); - if (argumentNames.empty()) + for (Declaration const* declaration: m_overloadedDeclarations) { - // positional arguments - std::vector<Declaration const*> possibles; - for (Declaration const* declaration: m_overloadedDeclarations) - { - TypePointer const& function = declaration->getType(); - auto const& functionType = dynamic_cast<FunctionType const&>(*function); - TypePointers const& parameterTypes = functionType.getParameterTypes(); - - if (functionType.takesArbitraryParameters() || - (arguments.size() == parameterTypes.size() && - std::equal(arguments.cbegin(), arguments.cend(), parameterTypes.cbegin(), - [](ASTPointer<Expression const> const& argument, TypePointer const& parameterType) - { - return argument->getType()->isImplicitlyConvertibleTo(*parameterType); - }))) - possibles.push_back(declaration); - } - if (possibles.empty()) - BOOST_THROW_EXCEPTION(createTypeError("Can't resolve identifier")); - else if (std::none_of(possibles.cbegin() + 1, possibles.cend(), - [&possibles](Declaration const* declaration) - { - return declaration->getScope() == possibles.front()->getScope(); - })) - return possibles.front(); - else - BOOST_THROW_EXCEPTION(createTypeError("Can't resolve identifier")); + TypePointer const& function = declaration->getType(); + auto const* functionType = dynamic_cast<FunctionType const*>(function.get()); + if (functionType && functionType->canTakeArguments(_argumentTypes)) + possibles.push_back(declaration); } + if (possibles.size() == 1) + m_referencedDeclaration = possibles.front(); + else if (possibles.empty()) + BOOST_THROW_EXCEPTION(createTypeError("No matching declaration found after argument-dependent lookup.")); else - // named arguments - // TODO: don't support right now - BOOST_THROW_EXCEPTION(createTypeError("Named arguments with overloaded functions are not supported yet.")); - return nullptr; + BOOST_THROW_EXCEPTION(createTypeError("No unique declaration found after argument-dependent lookup.")); } -void ElementaryTypeNameExpression::checkTypeRequirements() +void ElementaryTypeNameExpression::checkTypeRequirements(TypePointers const*) { m_type = make_shared<TypeType>(Type::fromElementaryTypeName(m_typeToken)); } -void Literal::checkTypeRequirements() +void Literal::checkTypeRequirements(TypePointers const*) { m_type = Type::forLiteral(*this); if (!m_type) @@ -282,8 +282,14 @@ public: FunctionDefinition const* getFallbackFunction() const; private: + /// Checks that two functions defined in this contract with the same name have different + /// arguments and that there is at most one constructor. + void checkDuplicateFunctions() const; void checkIllegalOverrides() const; void checkAbstractFunctions(); + /// Checks that different functions with external visibility end up having different + /// external argument types (i.e. different signature). + void checkExternalTypeClashes() const; std::vector<std::pair<FixedHash<4>, FunctionTypePointer>> const& getInterfaceFunctionList() const; @@ -967,7 +973,10 @@ class Expression: public ASTNode { public: Expression(SourceLocation const& _location): ASTNode(_location) {} - virtual void checkTypeRequirements() = 0; + /// Performs type checking after which m_type should be set. + /// @arg _argumentTypes if set, provides the argument types for the case that this expression + /// is used in the context of a call, used for function overload resolution. + virtual void checkTypeRequirements(TypePointers const* _argumentTypes) = 0; std::shared_ptr<Type const> const& getType() const { return m_type; } bool isLValue() const { return m_isLValue; } @@ -1006,7 +1015,7 @@ public: } virtual void accept(ASTVisitor& _visitor) override; virtual void accept(ASTConstVisitor& _visitor) const override; - virtual void checkTypeRequirements() override; + virtual void checkTypeRequirements(TypePointers const* _argumentTypes) override; Expression const& getLeftHandSide() const { return *m_leftHandSide; } Token::Value getAssignmentOperator() const { return m_assigmentOperator; } @@ -1034,7 +1043,7 @@ public: } virtual void accept(ASTVisitor& _visitor) override; virtual void accept(ASTConstVisitor& _visitor) const override; - virtual void checkTypeRequirements() override; + virtual void checkTypeRequirements(TypePointers const* _argumentTypes) override; Token::Value getOperator() const { return m_operator; } bool isPrefixOperation() const { return m_isPrefix; } @@ -1061,7 +1070,7 @@ public: } virtual void accept(ASTVisitor& _visitor) override; virtual void accept(ASTConstVisitor& _visitor) const override; - virtual void checkTypeRequirements() override; + virtual void checkTypeRequirements(TypePointers const* _argumentTypes) override; Expression const& getLeftExpression() const { return *m_left; } Expression const& getRightExpression() const { return *m_right; } @@ -1089,7 +1098,7 @@ public: Expression(_location), m_expression(_expression), m_arguments(_arguments), m_names(_names) {} virtual void accept(ASTVisitor& _visitor) override; virtual void accept(ASTConstVisitor& _visitor) const override; - virtual void checkTypeRequirements() override; + virtual void checkTypeRequirements(TypePointers const* _argumentTypes) override; Expression const& getExpression() const { return *m_expression; } std::vector<ASTPointer<Expression const>> getArguments() const { return {m_arguments.begin(), m_arguments.end()}; } @@ -1115,7 +1124,7 @@ public: Expression(_location), m_contractName(_contractName) {} virtual void accept(ASTVisitor& _visitor) override; virtual void accept(ASTConstVisitor& _visitor) const override; - virtual void checkTypeRequirements() override; + virtual void checkTypeRequirements(TypePointers const* _argumentTypes) override; /// Returns the referenced contract. Can only be called after type checking. ContractDefinition const* getContract() const { solAssert(m_contract, ""); return m_contract; } @@ -1139,11 +1148,18 @@ public: virtual void accept(ASTConstVisitor& _visitor) const override; Expression const& getExpression() const { return *m_expression; } ASTString const& getMemberName() const { return *m_memberName; } - virtual void checkTypeRequirements() override; + /// @returns the declaration referenced by this expression. Might return nullptr even if the + /// expression is valid, e.g. if the member does not correspond to an AST node. + Declaration const* referencedDeclaration() const { return m_referencedDeclaration; } + virtual void checkTypeRequirements(TypePointers const* _argumentTypes) override; private: ASTPointer<Expression> m_expression; ASTPointer<ASTString> m_memberName; + + /// Pointer to the referenced declaration, this is sometimes needed to resolve function over + /// loads in the type-checking phase. + Declaration const* m_referencedDeclaration = nullptr; }; /** @@ -1157,7 +1173,7 @@ public: Expression(_location), m_base(_base), m_index(_index) {} virtual void accept(ASTVisitor& _visitor) override; virtual void accept(ASTConstVisitor& _visitor) const override; - virtual void checkTypeRequirements() override; + virtual void checkTypeRequirements(TypePointers const* _argumentTypes) override; Expression const& getBaseExpression() const { return *m_base; } Expression const* getIndexExpression() const { return m_index.get(); } @@ -1187,28 +1203,33 @@ public: PrimaryExpression(_location), m_name(_name) {} virtual void accept(ASTVisitor& _visitor) override; virtual void accept(ASTConstVisitor& _visitor) const override; - virtual void checkTypeRequirements() override; + virtual void checkTypeRequirements(TypePointers const* _argumentTypes) override; ASTString const& getName() const { return *m_name; } - void setReferencedDeclaration(Declaration const& _referencedDeclaration, - ContractDefinition const* _currentContract = nullptr) + void setReferencedDeclaration( + Declaration const& _referencedDeclaration, + ContractDefinition const* _currentContract = nullptr + ) { m_referencedDeclaration = &_referencedDeclaration; m_currentContract = _currentContract; } - Declaration const* getReferencedDeclaration() const { return m_referencedDeclaration; } - ContractDefinition const* getCurrentContract() const { return m_currentContract; } + Declaration const& getReferencedDeclaration() const; - void setOverloadedDeclarations(std::set<Declaration const*> const& _declarations) { m_overloadedDeclarations = _declarations; } - std::set<Declaration const*> getOverloadedDeclarations() const { return m_overloadedDeclarations; } + /// Stores a set of possible declarations referenced by this identifier. Has to be resolved + /// providing argument types using overloadResolution before the referenced declaration + /// is accessed. + void setOverloadedDeclarations(std::set<Declaration const*> const& _declarations) + { + m_overloadedDeclarations = _declarations; + } - void checkTypeRequirementsWithFunctionCall(FunctionCall const& _functionCall); - void checkTypeRequirementsFromVariableDeclaration(); + /// Tries to find exactly one of the possible referenced declarations provided the given + /// argument types in a call context. + void overloadResolution(TypePointers const& _argumentTypes); - Declaration const* overloadResolution(FunctionCall const& _functionCall); private: - ASTPointer<ASTString> m_name; /// Declaration the name refers to. @@ -1235,7 +1256,7 @@ public: } virtual void accept(ASTVisitor& _visitor) override; virtual void accept(ASTConstVisitor& _visitor) const override; - virtual void checkTypeRequirements() override; + virtual void checkTypeRequirements(TypePointers const* _argumentTypes) override; Token::Value getTypeToken() const { return m_typeToken; } @@ -1269,7 +1290,7 @@ public: PrimaryExpression(_location), m_token(_token), m_value(_value), m_subDenomination(_sub) {} virtual void accept(ASTVisitor& _visitor) override; virtual void accept(ASTConstVisitor& _visitor) const override; - virtual void checkTypeRequirements() override; + virtual void checkTypeRequirements(TypePointers const* _argumentTypes) override; Token::Value getToken() const { return m_token; } /// @returns the non-parsed value of the literal diff --git a/Compiler.cpp b/Compiler.cpp index 8e263449..a8d0a591 100644 --- a/Compiler.cpp +++ b/Compiler.cpp @@ -90,7 +90,7 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp for (auto const& modifier: constructor->getModifiers()) { auto baseContract = dynamic_cast<ContractDefinition const*>( - modifier->getName()->getReferencedDeclaration()); + &modifier->getName()->getReferencedDeclaration()); if (baseContract) if (m_baseArguments.count(baseContract->getConstructor()) == 0) m_baseArguments[baseContract->getConstructor()] = &modifier->getArguments(); @@ -99,7 +99,7 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp for (ASTPointer<InheritanceSpecifier> const& base: contract->getBaseContracts()) { ContractDefinition const* baseContract = dynamic_cast<ContractDefinition const*>( - base->getName()->getReferencedDeclaration()); + &base->getName()->getReferencedDeclaration()); solAssert(baseContract, ""); if (m_baseArguments.count(baseContract->getConstructor()) == 0) @@ -545,7 +545,7 @@ void Compiler::appendModifierOrFunctionCode() ASTPointer<ModifierInvocation> const& modifierInvocation = m_currentFunction->getModifiers()[m_modifierDepth]; // constructor call should be excluded - if (dynamic_cast<ContractDefinition const*>(modifierInvocation->getName()->getReferencedDeclaration())) + if (dynamic_cast<ContractDefinition const*>(&modifierInvocation->getName()->getReferencedDeclaration())) { ++m_modifierDepth; appendModifierOrFunctionCode(); diff --git a/CompilerContext.cpp b/CompilerContext.cpp index 0afda136..f373fdfb 100644 --- a/CompilerContext.cpp +++ b/CompilerContext.cpp @@ -102,27 +102,13 @@ eth::AssemblyItem CompilerContext::getFunctionEntryLabel(Declaration const& _dec eth::AssemblyItem CompilerContext::getVirtualFunctionEntryLabel(FunctionDefinition const& _function) { solAssert(!m_inheritanceHierarchy.empty(), "No inheritance hierarchy set."); - for (ContractDefinition const* contract: m_inheritanceHierarchy) - for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) - { - if (!function->isConstructor() && - dynamic_cast<FunctionType const&>(*function->getType(contract)).getCanonicalSignature() == - dynamic_cast<FunctionType const&>(*_function.getType(contract)).getCanonicalSignature()) - return getFunctionEntryLabel(*function); - } - solAssert(false, "Virtual function " + _function.getName() + " not found."); - return m_asm.newTag(); // not reached + return getVirtualFunctionEntryLabel(_function, m_inheritanceHierarchy.begin()); } -eth::AssemblyItem CompilerContext::getSuperFunctionEntryLabel(string const& _name, ContractDefinition const& _base) +eth::AssemblyItem CompilerContext::getSuperFunctionEntryLabel(FunctionDefinition const& _function, ContractDefinition const& _base) { - auto it = getSuperContract(_base); - for (; it != m_inheritanceHierarchy.end(); ++it) - for (ASTPointer<FunctionDefinition> const& function: (*it)->getDefinedFunctions()) - if (!function->isConstructor() && function->getName() == _name) // TODO: add a test case for this! - return getFunctionEntryLabel(*function); - solAssert(false, "Super function " + _name + " not found."); - return m_asm.newTag(); // not reached + solAssert(!m_inheritanceHierarchy.empty(), "No inheritance hierarchy set."); + return getVirtualFunctionEntryLabel(_function, getSuperContract(_base)); } FunctionDefinition const* CompilerContext::getNextConstructor(ContractDefinition const& _contract) const @@ -194,6 +180,26 @@ void CompilerContext::resetVisitedNodes(ASTNode const* _node) updateSourceLocation(); } +eth::AssemblyItem CompilerContext::getVirtualFunctionEntryLabel( + FunctionDefinition const& _function, + vector<ContractDefinition const*>::const_iterator _searchStart +) +{ + string name = _function.getName(); + FunctionType functionType(_function); + auto it = _searchStart; + for (; it != m_inheritanceHierarchy.end(); ++it) + for (ASTPointer<FunctionDefinition> const& function: (*it)->getDefinedFunctions()) + if ( + function->getName() == name && + !function->isConstructor() && + FunctionType(*function).hasEqualArgumentTypes(functionType) + ) + return getFunctionEntryLabel(*function); + solAssert(false, "Super function " + name + " not found."); + return m_asm.newTag(); // not reached +} + vector<ContractDefinition const*>::const_iterator CompilerContext::getSuperContract(ContractDefinition const& _contract) const { solAssert(!m_inheritanceHierarchy.empty(), "No inheritance hierarchy set."); diff --git a/CompilerContext.h b/CompilerContext.h index 87f90d4c..67dd3c94 100644 --- a/CompilerContext.h +++ b/CompilerContext.h @@ -63,9 +63,9 @@ public: void setInheritanceHierarchy(std::vector<ContractDefinition const*> const& _hierarchy) { m_inheritanceHierarchy = _hierarchy; } /// @returns the entry label of the given function and takes overrides into account. eth::AssemblyItem getVirtualFunctionEntryLabel(FunctionDefinition const& _function); - /// @returns the entry label of function with the given name from the most derived class just + /// @returns the entry label of a function that overrides the given declaration from the most derived class just /// above _base in the current inheritance hierarchy. - eth::AssemblyItem getSuperFunctionEntryLabel(std::string const& _name, ContractDefinition const& _base); + eth::AssemblyItem getSuperFunctionEntryLabel(FunctionDefinition const& _function, ContractDefinition const& _base); FunctionDefinition const* getNextConstructor(ContractDefinition const& _contract) const; /// @returns the set of functions for which we still need to generate code @@ -136,6 +136,13 @@ public: }; private: + /// @returns the entry label of the given function - searches the inheritance hierarchy + /// startig from the given point towards the base. + eth::AssemblyItem getVirtualFunctionEntryLabel( + FunctionDefinition const& _function, + std::vector<ContractDefinition const*>::const_iterator _searchStart + ); + /// @returns an iterator to the contract directly above the given contract. std::vector<ContractDefinition const*>::const_iterator getSuperContract(const ContractDefinition &_contract) const; /// Updates source location set in the assembly. void updateSourceLocation(); diff --git a/DeclarationContainer.cpp b/DeclarationContainer.cpp index 226b9d68..565f71df 100644 --- a/DeclarationContainer.cpp +++ b/DeclarationContainer.cpp @@ -35,30 +35,29 @@ bool DeclarationContainer::registerDeclaration(Declaration const& _declaration, if (name.empty()) return true; - if (!_update) + if (_update) + { + solAssert(!dynamic_cast<FunctionDefinition const*>(&_declaration), "Attempt to update function definition."); + m_declarations[name].clear(); + m_invisibleDeclarations[name].clear(); + } + else { if (dynamic_cast<FunctionDefinition const*>(&_declaration)) { - // other declarations must be FunctionDefinition, otherwise clash with other declarations. - for (auto&& declaration: m_declarations[_declaration.getName()]) - if (dynamic_cast<FunctionDefinition const*>(declaration) == nullptr) + // check that all other declarations with the same name are functions + for (auto&& declaration: m_invisibleDeclarations[name] + m_declarations[name]) + if (!dynamic_cast<FunctionDefinition const*>(declaration)) return false; } - else if (m_declarations.count(_declaration.getName()) != 0) - return false; - } - else - { - // update declaration - solAssert(dynamic_cast<FunctionDefinition const*>(&_declaration) == nullptr, "cannot be FunctionDefinition"); - - m_declarations[_declaration.getName()].clear(); + else if (m_declarations.count(name) > 0 || m_invisibleDeclarations.count(name) > 0) + return false; } if (_invisible) - m_invisibleDeclarations.insert(name); + m_invisibleDeclarations[name].insert(&_declaration); else - m_declarations[_declaration.getName()].insert(&_declaration); + m_declarations[name].insert(&_declaration); return true; } diff --git a/DeclarationContainer.h b/DeclarationContainer.h index 42784ec2..35a6ea07 100644 --- a/DeclarationContainer.h +++ b/DeclarationContainer.h @@ -56,7 +56,7 @@ private: Declaration const* m_enclosingDeclaration; DeclarationContainer const* m_enclosingContainer; std::map<ASTString, std::set<Declaration const*>> m_declarations; - std::set<ASTString> m_invisibleDeclarations; + std::map<ASTString, std::set<Declaration const*>> m_invisibleDeclarations; }; } diff --git a/ExpressionCompiler.cpp b/ExpressionCompiler.cpp index 3ca8de89..8c8c3ee0 100644 --- a/ExpressionCompiler.cpp +++ b/ExpressionCompiler.cpp @@ -601,13 +601,25 @@ void ExpressionCompiler::endVisit(MemberAccess const& _memberAccess) bool alsoSearchInteger = false; ContractType const& type = dynamic_cast<ContractType const&>(*_memberAccess.getExpression().getType()); if (type.isSuper()) - m_context << m_context.getSuperFunctionEntryLabel(member, type.getContractDefinition()).pushTag(); + { + solAssert(!!_memberAccess.referencedDeclaration(), "Referenced declaration not resolved."); + m_context << m_context.getSuperFunctionEntryLabel( + dynamic_cast<FunctionDefinition const&>(*_memberAccess.referencedDeclaration()), + type.getContractDefinition() + ).pushTag(); + } else { // ordinary contract type - u256 identifier = type.getFunctionIdentifier(member); - if (identifier != Invalid256) + if (Declaration const* declaration = _memberAccess.referencedDeclaration()) { + u256 identifier; + if (auto const* variable = dynamic_cast<VariableDeclaration const*>(declaration)) + identifier = FunctionType(*variable).externalIdentifier(); + else if (auto const* function = dynamic_cast<FunctionDefinition const*>(declaration)) + identifier = FunctionType(*function).externalIdentifier(); + else + solAssert(false, "Contract member is neither variable nor function."); appendTypeConversion(type, IntegerType(0, IntegerType::Modifier::Address), true); m_context << identifier; } @@ -683,19 +695,16 @@ void ExpressionCompiler::endVisit(MemberAccess const& _memberAccess) case Type::Category::TypeType: { TypeType const& type = dynamic_cast<TypeType const&>(*_memberAccess.getExpression().getType()); - if (!type.getMembers().getMemberType(member)) - BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Invalid member access to " + type.toString())); + solAssert( + !type.getMembers().membersByName(_memberAccess.getMemberName()).empty(), + "Invalid member access to " + type.toString() + ); - if (auto contractType = dynamic_cast<ContractType const*>(type.getActualType().get())) + if (dynamic_cast<ContractType const*>(type.getActualType().get())) { - ContractDefinition const& contract = contractType->getContractDefinition(); - for (ASTPointer<FunctionDefinition> const& function: contract.getDefinedFunctions()) - if (function->getName() == member) - { - m_context << m_context.getFunctionEntryLabel(*function).pushTag(); - return; - } - solAssert(false, "Function not found in member access."); + auto const* function = dynamic_cast<FunctionDefinition const*>(_memberAccess.referencedDeclaration()); + solAssert(!!function, "Function not found in member access"); + m_context << m_context.getFunctionEntryLabel(*function).pushTag(); } else if (auto enumType = dynamic_cast<EnumType const*>(type.getActualType().get())) m_context << enumType->getMemberValue(_memberAccess.getMemberName()); @@ -780,7 +789,7 @@ bool ExpressionCompiler::visit(IndexAccess const& _indexAccess) void ExpressionCompiler::endVisit(Identifier const& _identifier) { CompilerContext::LocationSetter locationSetter(m_context, _identifier); - Declaration const* declaration = _identifier.getReferencedDeclaration(); + Declaration const* declaration = &_identifier.getReferencedDeclaration(); if (MagicVariableDeclaration const* magicVar = dynamic_cast<MagicVariableDeclaration const*>(declaration)) { switch (magicVar->getType()->getCategory()) @@ -819,13 +828,6 @@ void ExpressionCompiler::endVisit(Identifier const& _identifier) { // no-op } - else if (declaration == nullptr && _identifier.getOverloadedDeclarations().size() > 1) - { - // var x = f; - declaration = *_identifier.getOverloadedDeclarations().begin(); - FunctionDefinition const* functionDef = dynamic_cast<FunctionDefinition const*>(declaration); - m_context << m_context.getVirtualFunctionEntryLabel(*functionDef).pushTag(); - } else { BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Identifier type not expected in expression context.")); @@ -177,10 +177,10 @@ void StorageItem::storeValue(Type const& _sourceType, SourceLocation const& _loc for (auto const& member: structType.getMembers()) { // assign each member that is not a mapping - TypePointer const& memberType = member.second; + TypePointer const& memberType = member.type; if (memberType->getCategory() == Type::Category::Mapping) continue; - pair<u256, unsigned> const& offsets = structType.getStorageOffsetsOfMember(member.first); + pair<u256, unsigned> const& offsets = structType.getStorageOffsetsOfMember(member.name); m_context << offsets.first << u256(offsets.second) << eth::Instruction::DUP6 << eth::Instruction::DUP3 @@ -230,10 +230,10 @@ void StorageItem::setToZero(SourceLocation const&, bool _removeReference) const for (auto const& member: structType.getMembers()) { // zero each member that is not a mapping - TypePointer const& memberType = member.second; + TypePointer const& memberType = member.type; if (memberType->getCategory() == Type::Category::Mapping) continue; - pair<u256, unsigned> const& offsets = structType.getStorageOffsetsOfMember(member.first); + pair<u256, unsigned> const& offsets = structType.getStorageOffsetsOfMember(member.name); m_context << offsets.first << eth::Instruction::DUP3 << eth::Instruction::ADD << u256(offsets.second); diff --git a/NameAndTypeResolver.cpp b/NameAndTypeResolver.cpp index c787ae6b..1c527b89 100644 --- a/NameAndTypeResolver.cpp +++ b/NameAndTypeResolver.cpp @@ -53,8 +53,9 @@ void NameAndTypeResolver::resolveNamesAndTypes(ContractDefinition& _contract) m_currentScope = &m_scopes[&_contract]; linearizeBaseContracts(_contract); + // we first import non-functions only as we do not yet know the argument types for (ContractDefinition const* base: _contract.getLinearizedBaseContracts()) - importInheritedScope(*base); + importInheritedScope(*base, false); // import non-functions for (ASTPointer<StructDefinition> const& structDef: _contract.getDefinedStructs()) ReferencesResolver resolver(*structDef, *this, &_contract, nullptr); @@ -64,6 +65,8 @@ void NameAndTypeResolver::resolveNamesAndTypes(ContractDefinition& _contract) ReferencesResolver resolver(*variable, *this, &_contract, nullptr); for (ASTPointer<EventDefinition> const& event: _contract.getEvents()) ReferencesResolver resolver(*event, *this, &_contract, nullptr); + + // these can contain code, only resolve parameters for now for (ASTPointer<ModifierDefinition> const& modifier: _contract.getFunctionModifiers()) { m_currentScope = &m_scopes[modifier.get()]; @@ -75,6 +78,28 @@ void NameAndTypeResolver::resolveNamesAndTypes(ContractDefinition& _contract) ReferencesResolver referencesResolver(*function, *this, &_contract, function->getReturnParameterList().get()); } + + m_currentScope = &m_scopes[&_contract]; + for (ContractDefinition const* base: _contract.getLinearizedBaseContracts()) + importInheritedScope(*base, true); // import functions + + // now resolve references inside the code + for (ASTPointer<ModifierDefinition> const& modifier: _contract.getFunctionModifiers()) + { + m_currentScope = &m_scopes[modifier.get()]; + ReferencesResolver resolver(*modifier, *this, &_contract, nullptr, true); + } + for (ASTPointer<FunctionDefinition> const& function: _contract.getDefinedFunctions()) + { + m_currentScope = &m_scopes[function.get()]; + ReferencesResolver referencesResolver( + *function, + *this, + &_contract, + function->getReturnParameterList().get(), + true + ); + } } void NameAndTypeResolver::checkTypeRequirements(ContractDefinition& _contract) @@ -90,7 +115,7 @@ void NameAndTypeResolver::updateDeclaration(Declaration const& _declaration) solAssert(_declaration.getScope() == nullptr, "Updated declaration outside global scope."); } -std::set<Declaration const*> NameAndTypeResolver::resolveName(ASTString const& _name, Declaration const* _scope) const +set<Declaration const*> NameAndTypeResolver::resolveName(ASTString const& _name, Declaration const* _scope) const { auto iterator = m_scopes.find(_scope); if (iterator == end(m_scopes)) @@ -98,21 +123,43 @@ std::set<Declaration const*> NameAndTypeResolver::resolveName(ASTString const& _ return iterator->second.resolveName(_name, false); } -std::set<Declaration const*> NameAndTypeResolver::getNameFromCurrentScope(ASTString const& _name, bool _recursive) +set<Declaration const*> NameAndTypeResolver::getNameFromCurrentScope(ASTString const& _name, bool _recursive) { return m_currentScope->resolveName(_name, _recursive); } -void NameAndTypeResolver::importInheritedScope(ContractDefinition const& _base) +void NameAndTypeResolver::importInheritedScope(ContractDefinition const& _base, bool _importFunctions) { auto iterator = m_scopes.find(&_base); solAssert(iterator != end(m_scopes), ""); for (auto const& nameAndDeclaration: iterator->second.getDeclarations()) for (auto const& declaration: nameAndDeclaration.second) // Import if it was declared in the base, is not the constructor and is visible in derived classes - if (declaration->getScope() == &_base && declaration->getName() != _base.getName() && - declaration->isVisibleInDerivedContracts()) + if (declaration->getScope() == &_base && declaration->isVisibleInDerivedContracts()) + { + auto function = dynamic_cast<FunctionDefinition const*>(declaration); + if ((function == nullptr) == _importFunctions) + continue; + if (!!function) + { + FunctionType functionType(*function); + // only import if a function with the same arguments does not exist yet + bool functionWithEqualArgumentsFound = false; + for (auto knownDeclaration: m_currentScope->resolveName(nameAndDeclaration.first)) + { + auto knownFunction = dynamic_cast<FunctionDefinition const*>(knownDeclaration); + if (!knownFunction) + continue; // this is not legal, but will be caught later + if (!FunctionType(*knownFunction).hasEqualArgumentTypes(functionType)) + continue; + functionWithEqualArgumentsFound = true; + break; + } + if (functionWithEqualArgumentsFound) + continue; + } m_currentScope->registerDeclaration(*declaration); + } } void NameAndTypeResolver::linearizeBaseContracts(ContractDefinition& _contract) const @@ -123,8 +170,7 @@ void NameAndTypeResolver::linearizeBaseContracts(ContractDefinition& _contract) for (ASTPointer<InheritanceSpecifier> const& baseSpecifier: _contract.getBaseContracts()) { ASTPointer<Identifier> baseName = baseSpecifier->getName(); - ContractDefinition const* base = dynamic_cast<ContractDefinition const*>( - baseName->getReferencedDeclaration()); + auto base = dynamic_cast<ContractDefinition const*>(&baseName->getReferencedDeclaration()); if (!base) BOOST_THROW_EXCEPTION(baseName->createTypeError("Contract expected.")); // "push_front" has the effect that bases mentioned later can overwrite members of bases @@ -316,11 +362,19 @@ void DeclarationRegistrationHelper::registerDeclaration(Declaration& _declaratio enterNewSubScope(_declaration); } -ReferencesResolver::ReferencesResolver(ASTNode& _root, NameAndTypeResolver& _resolver, - ContractDefinition const* _currentContract, - ParameterList const* _returnParameters, bool _allowLazyTypes): - m_resolver(_resolver), m_currentContract(_currentContract), - m_returnParameters(_returnParameters), m_allowLazyTypes(_allowLazyTypes) +ReferencesResolver::ReferencesResolver( + ASTNode& _root, + NameAndTypeResolver& _resolver, + ContractDefinition const* _currentContract, + ParameterList const* _returnParameters, + bool _resolveInsideCode, + bool _allowLazyTypes +): + m_resolver(_resolver), + m_currentContract(_currentContract), + m_returnParameters(_returnParameters), + m_resolveInsideCode(_resolveInsideCode), + m_allowLazyTypes(_allowLazyTypes) { _root.accept(*this); } diff --git a/NameAndTypeResolver.h b/NameAndTypeResolver.h index 82877617..6528bbef 100644 --- a/NameAndTypeResolver.h +++ b/NameAndTypeResolver.h @@ -65,9 +65,10 @@ public: private: void reset(); - /// Imports all members declared directly in the given contract (i.e. does not import inherited - /// members) into the current scope if they are not present already. - void importInheritedScope(ContractDefinition const& _base); + /// Either imports all non-function members or all function members declared directly in the + /// given contract (i.e. does not import inherited members) into the current scope if they are + ///not present already. + void importInheritedScope(ContractDefinition const& _base, bool _importFunctions); /// Computes "C3-Linearization" of base contracts and stores it inside the contract. void linearizeBaseContracts(ContractDefinition& _contract) const; @@ -126,13 +127,18 @@ private: class ReferencesResolver: private ASTVisitor { public: - ReferencesResolver(ASTNode& _root, NameAndTypeResolver& _resolver, - ContractDefinition const* _currentContract, - ParameterList const* _returnParameters, - bool _allowLazyTypes = true); + ReferencesResolver( + ASTNode& _root, + NameAndTypeResolver& _resolver, + ContractDefinition const* _currentContract, + ParameterList const* _returnParameters, + bool _resolveInsideCode = false, + bool _allowLazyTypes = true + ); private: virtual void endVisit(VariableDeclaration& _variable) override; + virtual bool visit(Block&) override { return m_resolveInsideCode; } virtual bool visit(Identifier& _identifier) override; virtual bool visit(UserDefinedTypeName& _typeName) override; virtual bool visit(Mapping&) override; @@ -141,6 +147,7 @@ private: NameAndTypeResolver& m_resolver; ContractDefinition const* m_currentContract; ParameterList const* m_returnParameters; + bool m_resolveInsideCode; bool m_allowLazyTypes; }; @@ -25,6 +25,7 @@ #include <boost/range/adaptor/reversed.hpp> #include <libdevcore/CommonIO.h> #include <libdevcore/CommonData.h> +#include <libdevcrypto/SHA3.h> #include <libsolidity/Utils.h> #include <libsolidity/AST.h> @@ -92,13 +93,13 @@ std::pair<u256, unsigned> const* MemberList::getMemberStorageOffset(string const { TypePointers memberTypes; memberTypes.reserve(m_memberTypes.size()); - for (auto const& nameAndType: m_memberTypes) - memberTypes.push_back(nameAndType.second); + for (auto const& member: m_memberTypes) + memberTypes.push_back(member.type); m_storageOffsets.reset(new StorageOffsets()); m_storageOffsets->computeOffsets(memberTypes); } for (size_t index = 0; index < m_memberTypes.size(); ++index) - if (m_memberTypes[index].first == _name) + if (m_memberTypes[index].name == _name) return m_storageOffsets->getOffset(index); return nullptr; } @@ -189,7 +190,7 @@ TypePointer Type::fromArrayTypeName(TypeName& _baseTypeName, Expression* _length if (_length) { if (!_length->getType()) - _length->checkTypeRequirements(); + _length->checkTypeRequirements(nullptr); auto const* length = dynamic_cast<IntegerConstantType const*>(_length->getType().get()); if (!length) BOOST_THROW_EXCEPTION(_length->createTypeError("Invalid array length.")); @@ -793,18 +794,46 @@ MemberList const& ContractType::getMembers() const if (!m_members) { // All address members and all interface functions - vector<pair<string, TypePointer>> members(IntegerType::AddressMemberList.begin(), - IntegerType::AddressMemberList.end()); + MemberList::MemberMap members( + IntegerType::AddressMemberList.begin(), + IntegerType::AddressMemberList.end() + ); if (m_super) { + // add the most derived of all functions which are visible in derived contracts for (ContractDefinition const* base: m_contract.getLinearizedBaseContracts()) for (ASTPointer<FunctionDefinition> const& function: base->getDefinedFunctions()) - if (function->isVisibleInDerivedContracts()) - members.push_back(make_pair(function->getName(), make_shared<FunctionType>(*function, true))); + { + if (!function->isVisibleInDerivedContracts()) + continue; + auto functionType = make_shared<FunctionType>(*function, true); + bool functionWithEqualArgumentsFound = false; + for (auto const& member: members) + { + if (member.name != function->getName()) + continue; + auto memberType = dynamic_cast<FunctionType const*>(member.type.get()); + solAssert(!!memberType, "Override changes type."); + if (!memberType->hasEqualArgumentTypes(*functionType)) + continue; + functionWithEqualArgumentsFound = true; + break; + } + if (!functionWithEqualArgumentsFound) + members.push_back(MemberList::Member( + function->getName(), + functionType, + function.get() + )); + } } else for (auto const& it: m_contract.getInterfaceFunctions()) - members.push_back(make_pair(it.second->getDeclaration().getName(), it.second)); + members.push_back(MemberList::Member( + it.second->getDeclaration().getName(), + it.second, + &it.second->getDeclaration() + )); m_members.reset(new MemberList(members)); } return *m_members; @@ -823,16 +852,6 @@ shared_ptr<FunctionType const> const& ContractType::getConstructorType() const return m_constructorType; } -u256 ContractType::getFunctionIdentifier(string const& _functionName) const -{ - auto interfaceFunctions = m_contract.getInterfaceFunctions(); - for (auto const& it: m_contract.getInterfaceFunctions()) - if (it.second->getDeclaration().getName() == _functionName) - return FixedHash<4>::Arith(it.first); - - return Invalid256; -} - vector<tuple<VariableDeclaration const*, u256, unsigned>> ContractType::getStateVariables() const { vector<VariableDeclaration const*> variables; @@ -873,8 +892,8 @@ u256 StructType::getStorageSize() const bool StructType::canLiveOutsideStorage() const { - for (pair<string, TypePointer> const& member: getMembers()) - if (!member.second->canLiveOutsideStorage()) + for (auto const& member: getMembers()) + if (!member.type->canLiveOutsideStorage()) return false; return true; } @@ -891,7 +910,7 @@ MemberList const& StructType::getMembers() const { MemberList::MemberMap members; for (ASTPointer<VariableDeclaration> const& variable: m_struct.getMembers()) - members.push_back(make_pair(variable->getName(), variable->getType())); + members.push_back(MemberList::Member(variable->getName(), variable->getType(), variable.get())); m_members.reset(new MemberList(members)); } return *m_members; @@ -996,11 +1015,11 @@ FunctionType::FunctionType(VariableDeclaration const& _varDecl): vector<string> retParamNames; if (auto structType = dynamic_cast<StructType const*>(returnType.get())) { - for (pair<string, TypePointer> const& member: structType->getMembers()) - if (member.second->canLiveOutsideStorage()) + for (auto const& member: structType->getMembers()) + if (member.type->canLiveOutsideStorage()) { - retParamNames.push_back(member.first); - retParams.push_back(member.second); + retParamNames.push_back(member.name); + retParams.push_back(member.type); } } else @@ -1130,12 +1149,12 @@ MemberList const& FunctionType::getMembers() const case Location::Bare: if (!m_members) { - vector<pair<string, TypePointer>> members{ + MemberList::MemberMap members{ {"value", make_shared<FunctionType>(parseElementaryTypeVector({"uint"}), TypePointers{copyAndSetGasOrValue(false, true)}, Location::SetValue, false, m_gasSet, m_valueSet)}}; if (m_location != Location::Creation) - members.push_back(make_pair("gas", make_shared<FunctionType>( + members.push_back(MemberList::Member("gas", make_shared<FunctionType>( parseElementaryTypeVector({"uint"}), TypePointers{copyAndSetGasOrValue(true, false)}, Location::SetGas, false, m_gasSet, m_valueSet))); @@ -1147,6 +1166,37 @@ MemberList const& FunctionType::getMembers() const } } +bool FunctionType::canTakeArguments(TypePointers const& _argumentTypes) const +{ + TypePointers const& parameterTypes = getParameterTypes(); + if (takesArbitraryParameters()) + return true; + else if (_argumentTypes.size() != parameterTypes.size()) + return false; + else + return std::equal( + _argumentTypes.cbegin(), + _argumentTypes.cend(), + parameterTypes.cbegin(), + [](TypePointer const& argumentType, TypePointer const& parameterType) + { + return argumentType->isImplicitlyConvertibleTo(*parameterType); + } + ); +} + +bool FunctionType::hasEqualArgumentTypes(FunctionType const& _other) const +{ + if (m_parameterTypes.size() != _other.m_parameterTypes.size()) + return false; + return equal( + m_parameterTypes.cbegin(), + m_parameterTypes.cend(), + _other.m_parameterTypes.cbegin(), + [](TypePointer const& _a, TypePointer const& _b) -> bool { return *_a == *_b; } + ); +} + string FunctionType::externalSignature(std::string const& _name) const { std::string funcName = _name; @@ -1167,6 +1217,11 @@ string FunctionType::externalSignature(std::string const& _name) const return ret + ")"; } +u256 FunctionType::externalIdentifier() const +{ + return FixedHash<4>::Arith(FixedHash<4>(dev::sha3(externalSignature()))); +} + TypePointers FunctionType::parseElementaryTypeVector(strings const& _types) { TypePointers pointers; @@ -1250,7 +1305,7 @@ MemberList const& TypeType::getMembers() const // We need to lazy-initialize it because of recursive references. if (!m_members) { - vector<pair<string, TypePointer>> members; + MemberList::MemberMap members; if (m_actualType->getCategory() == Category::Contract && m_currentContract != nullptr) { ContractDefinition const& contract = dynamic_cast<ContractType const&>(*m_actualType).getContractDefinition(); @@ -1259,14 +1314,14 @@ MemberList const& TypeType::getMembers() const // We are accessing the type of a base contract, so add all public and protected // members. Note that this does not add inherited functions on purpose. for (Declaration const* decl: contract.getInheritableMembers()) - members.push_back(make_pair(decl->getName(), decl->getType())); + members.push_back(MemberList::Member(decl->getName(), decl->getType(), decl)); } else if (m_actualType->getCategory() == Category::Enum) { EnumDefinition const& enumDef = dynamic_cast<EnumType const&>(*m_actualType).getEnumDefinition(); auto enumType = make_shared<EnumType>(enumDef); for (ASTPointer<EnumValue> const& enumValue: enumDef.getMembers()) - members.push_back(make_pair(enumValue->getName(), enumType)); + members.push_back(MemberList::Member(enumValue->getName(), enumType)); } m_members.reset(new MemberList(members)); } @@ -69,17 +69,43 @@ private: class MemberList { public: - using MemberMap = std::vector<std::pair<std::string, TypePointer>>; + struct Member + { + Member(std::string const& _name, TypePointer const& _type, Declaration const* _declaration = nullptr): + name(_name), + type(_type), + declaration(_declaration) + { + } + + std::string name; + TypePointer type; + Declaration const* declaration = nullptr; + }; + + using MemberMap = std::vector<Member>; MemberList() {} explicit MemberList(MemberMap const& _members): m_memberTypes(_members) {} MemberList& operator=(MemberList&& _other); TypePointer getMemberType(std::string const& _name) const { + TypePointer type; for (auto const& it: m_memberTypes) - if (it.first == _name) - return it.second; - return TypePointer(); + if (it.name == _name) + { + solAssert(!type, "Requested member type by non-unique name."); + type = it.type; + } + return type; + } + MemberMap membersByName(std::string const& _name) const + { + MemberMap members; + for (auto const& it: m_memberTypes) + if (it.name == _name) + members.push_back(it); + return members; } /// @returns the offset of the given member in storage slots and bytes inside a slot or /// a nullptr if the member is not part of storage. @@ -104,7 +130,7 @@ public: enum class Category { Integer, IntegerConstant, Bool, Real, Array, - FixedBytes, Contract, Struct, Function, OverloadedFunctions, Enum, + FixedBytes, Contract, Struct, Function, Enum, Mapping, Void, TypeType, Modifier, Magic }; @@ -554,11 +580,18 @@ public: virtual unsigned getSizeOnStack() const override; virtual MemberList const& getMembers() const override; + /// @returns true if this function can take the given argument types (possibly + /// after implicit conversion). + bool canTakeArguments(TypePointers const& _arguments) const; + bool hasEqualArgumentTypes(FunctionType const& _other) const; + Location const& getLocation() const { return m_location; } /// @returns the external signature of this function type given the function name /// If @a _name is not provided (empty string) then the @c m_declaration member of the /// function type is used std::string externalSignature(std::string const& _name = "") const; + /// @returns the external identifier of this function (the hash of the signature). + u256 externalIdentifier() const; Declaration const& getDeclaration() const { solAssert(m_declaration, "Requested declaration from a FunctionType that has none"); @@ -597,20 +630,6 @@ private: Declaration const* m_declaration = nullptr; }; -class OverloadedFunctionType: public Type -{ -public: - explicit OverloadedFunctionType(Identifier* _identifier): m_identifier(_identifier) {} - - virtual Category getCategory() const override { return Category::OverloadedFunctions; } - virtual std::string toString() const override { return "OverloadedFunctions"; } - - Identifier* getIdentifier() const { return m_identifier; } - -private: - Identifier * m_identifier; -}; - /** * The type of a mapping, there is one distinct type per key/value type pair. * Mappings always occupy their own storage slot, but do not actually use it. |