diff options
author | Lu Guanqun <guanqun.lu@gmail.com> | 2015-03-01 11:34:39 +0800 |
---|---|---|
committer | Lu Guanqun <guanqun.lu@gmail.com> | 2015-03-08 22:50:06 +0800 |
commit | 3b9b71e0ae86cc20c6a0201b027bd45bee4257e5 (patch) | |
tree | c3b1ffc4e5d6d2e8013733401289928cfc51d086 | |
parent | e008f3f808b1483e7b7a5861ea4fe46fbe74bcff (diff) | |
download | dexon-solidity-3b9b71e0ae86cc20c6a0201b027bd45bee4257e5.tar.gz dexon-solidity-3b9b71e0ae86cc20c6a0201b027bd45bee4257e5.tar.zst dexon-solidity-3b9b71e0ae86cc20c6a0201b027bd45bee4257e5.zip |
implement overload resolution
-rw-r--r-- | AST.cpp | 159 | ||||
-rw-r--r-- | AST.h | 15 | ||||
-rw-r--r-- | CompilerContext.cpp | 4 | ||||
-rw-r--r-- | ExpressionCompiler.cpp | 6 | ||||
-rw-r--r-- | NameAndTypeResolver.cpp | 39 | ||||
-rw-r--r-- | NameAndTypeResolver.h | 4 | ||||
-rw-r--r-- | Parser.cpp | 7 | ||||
-rw-r--r-- | Types.h | 15 |
8 files changed, 210 insertions, 39 deletions
@@ -76,6 +76,15 @@ void ContractDefinition::checkTypeRequirements() for (ASTPointer<FunctionDefinition> const& function: getDefinedFunctions()) function->checkTypeRequirements(); + // check for duplicate declaration + set<string> functions; + for (ASTPointer<FunctionDefinition> const& function: getDefinedFunctions()) + { + string signature = function->getCanonicalSignature(); + if (functions.count(signature)) + BOOST_THROW_EXCEPTION(DeclarationError() << errinfo_comment("Duplicate functions are not allowed.")); + functions.insert(signature); + } for (ASTPointer<VariableDeclaration> const& variable: m_stateVariables) variable->checkTypeRequirements(); @@ -129,6 +138,7 @@ void ContractDefinition::checkIllegalOverrides() const // TODO unify this at a later point. for this we need to put the constness and the access specifier // into the types map<string, FunctionDefinition const*> functions; + set<string> functionNames; map<string, ModifierDefinition const*> modifiers; // We search from derived to base, so the stored item causes the error. @@ -141,7 +151,8 @@ void ContractDefinition::checkIllegalOverrides() const string const& name = function->getName(); if (modifiers.count(name)) BOOST_THROW_EXCEPTION(modifiers[name]->createTypeError("Override changes function to modifier.")); - FunctionDefinition const*& override = functions[name]; + FunctionDefinition const*& override = functions[function->getCanonicalSignature()]; + functionNames.insert(name); if (!override) override = function.get(); else if (override->getVisibility() != function->getVisibility() || @@ -152,13 +163,13 @@ void ContractDefinition::checkIllegalOverrides() const for (ASTPointer<ModifierDefinition> const& modifier: contract->getFunctionModifiers()) { string const& name = modifier->getName(); - if (functions.count(name)) - BOOST_THROW_EXCEPTION(functions[name]->createTypeError("Override changes modifier to function.")); ModifierDefinition const*& override = modifiers[name]; if (!override) override = modifier.get(); else if (ModifierType(*override) != ModifierType(*modifier)) BOOST_THROW_EXCEPTION(override->createTypeError("Override changes modifier signature.")); + if (functionNames.count(name)) + BOOST_THROW_EXCEPTION(override->createTypeError("Override changes modifier to function.")); } } } @@ -185,16 +196,21 @@ vector<pair<FixedHash<4>, FunctionTypePointer>> const& ContractDefinition::getIn if (!m_interfaceFunctionList) { set<string> functionsSeen; + set<string> signaturesSeen; m_interfaceFunctionList.reset(new vector<pair<FixedHash<4>, FunctionTypePointer>>()); for (ContractDefinition const* contract: getLinearizedBaseContracts()) { for (ASTPointer<FunctionDefinition> const& f: contract->getDefinedFunctions()) - if (f->isPublic() && !f->isConstructor() && !f->getName().empty() && functionsSeen.count(f->getName()) == 0) + { + string functionSignature = f->getCanonicalSignature(); + if (f->isPublic() && !f->isConstructor() && !f->getName().empty() && signaturesSeen.count(functionSignature) == 0) { functionsSeen.insert(f->getName()); - FixedHash<4> hash(dev::sha3(f->getCanonicalSignature())); + signaturesSeen.insert(functionSignature); + FixedHash<4> hash(dev::sha3(functionSignature)); m_interfaceFunctionList->push_back(make_pair(hash, make_shared<FunctionType>(*f, false))); } + } for (ASTPointer<VariableDeclaration> const& v: contract->getStateVariables()) if (v->isPublic() && functionsSeen.count(v->getName()) == 0) @@ -467,7 +483,43 @@ void Return::checkTypeRequirements() void VariableDeclarationStatement::checkTypeRequirements() { +<<<<<<< HEAD m_variable->checkTypeRequirements(); +======= + // Variables can be declared without type (with "var"), in which case the first assignment + // sets the type. + // Note that assignments before the first declaration are legal because of the special scoping + // rules inherited from JavaScript. + if (m_variable->getValue()) + { + if (m_variable->getType()) + { + std::cout << "getType() ok" << std::endl; + m_variable->getValue()->expectType(*m_variable->getType()); + } + else + { + // no type declared and no previous assignment, infer the type + std::cout << "here's where called...." << std::endl; + Identifier* identifier = dynamic_cast<Identifier*>(m_variable->getValue().get()); + if (identifier) + identifier->checkTypeRequirementsFromVariableDeclaration(); + else + m_variable->getValue()->checkTypeRequirements(); + TypePointer type = m_variable->getValue()->getType(); + if (type->getCategory() == Type::Category::IntegerConstant) + { + auto intType = dynamic_pointer_cast<IntegerConstantType const>(type)->getIntegerType(); + if (!intType) + BOOST_THROW_EXCEPTION(m_variable->getValue()->createTypeError("Invalid integer constant " + type->toString())); + type = intType; + } + else if (type->getCategory() == Type::Category::Void) + BOOST_THROW_EXCEPTION(m_variable->createTypeError("var cannot be void type")); + m_variable->setType(type); + } + } +>>>>>>> implement overload resolution } void Assignment::checkTypeRequirements() @@ -544,10 +596,16 @@ void BinaryOperation::checkTypeRequirements() void FunctionCall::checkTypeRequirements() { - m_expression->checkTypeRequirements(); + // we need to check arguments' type first as their info will be used by m_express(Identifier). for (ASTPointer<Expression> const& argument: m_arguments) argument->checkTypeRequirements(); + auto identifier = dynamic_cast<Identifier*>(m_expression.get()); + if (identifier) + identifier->checkTypeRequirementsWithFunctionCall(*this); + else + m_expression->checkTypeRequirements(); + Type const* expressionType = m_expression->getType().get(); if (isTypeConversion()) { @@ -617,6 +675,19 @@ void FunctionCall::checkTypeRequirements() else m_type = functionType->getReturnParameterTypes().front(); } + else if (OverloadedFunctionType const* overloadedTypes = dynamic_cast<OverloadedFunctionType const*>(expressionType)) + { + // this only applies to "x(3)" where x is assigned by "var x = f;" where f is an overloaded functions. + overloadedTypes->m_identifier->overloadResolution(*this); + FunctionType const* functionType = dynamic_cast<FunctionType const*>(overloadedTypes->m_identifier->getType().get()); + + // @todo actually the return type should be an anonymous struct, + // but we change it to the type of the first return value until we have structs + if (functionType->getReturnParameterTypes().empty()) + m_type = make_shared<VoidType>(); + else + m_type = functionType->getReturnParameterTypes().front(); + } else BOOST_THROW_EXCEPTION(createTypeError("Type is not callable.")); } @@ -709,16 +780,92 @@ void IndexAccess::checkTypeRequirements() } } +void Identifier::checkTypeRequirementsWithFunctionCall(FunctionCall const& _functionCall) +{ + solAssert(m_referencedDeclaration || !m_overloadedDeclarations.empty(), "Identifier not resolved."); + + if (!m_referencedDeclaration) + overloadResolution(_functionCall); + + checkTypeRequirements(); +} + +void Identifier::checkTypeRequirementsFromVariableDeclaration() +{ + solAssert(m_referencedDeclaration || !m_overloadedDeclarations.empty(), "Identifier not resolved."); + + if (!m_referencedDeclaration) + m_type = make_shared<OverloadedFunctionType>(m_overloadedDeclarations, this); + else + checkTypeRequirements(); + + m_isLValue = true; +} + void Identifier::checkTypeRequirements() { + // var x = f; TODO! solAssert(m_referencedDeclaration, "Identifier not resolved."); m_isLValue = m_referencedDeclaration->isLValue(); + if (m_isLValue) + std::cout << "Identifier: " << string(getName()) << " -> true" << std::endl; + else + std::cout << "Identifier: " << string(getName()) << " -> true" << std::endl; m_type = m_referencedDeclaration->getType(m_currentContract); if (!m_type) BOOST_THROW_EXCEPTION(createTypeError("Declaration referenced before type could be determined.")); } +void Identifier::overloadResolution(FunctionCall const& _functionCall) +{ + solAssert(m_overloadedDeclarations.size() > 1, "FunctionIdentifier not resolved."); + solAssert(!m_referencedDeclaration, "Referenced declaration should be null before overload resolution."); + + bool resolved = false; + + std::vector<ASTPointer<Expression const>> arguments = _functionCall.getArguments(); + std::vector<ASTPointer<ASTString>> const& argumentNames = _functionCall.getNames(); + + if (argumentNames.empty()) + { + // positional arguments + std::vector<Declaration const*> possibles; + for (Declaration const* declaration: m_overloadedDeclarations) + { + TypePointer const& function = declaration->getType(); + auto const& functionType = dynamic_cast<FunctionType const&>(*function); + TypePointers const& parameterTypes = functionType.getParameterTypes(); + + if (functionType.takesArbitraryParameters() || + (arguments.size() == parameterTypes.size() && + std::equal(arguments.cbegin(), arguments.cend(), parameterTypes.cbegin(), + [](ASTPointer<Expression const> const& argument, TypePointer const& parameterType) + { + return argument->getType()->isImplicitlyConvertibleTo(*parameterType); + }))) + possibles.push_back(declaration); + } + std::cout << "possibles: " << possibles.size() << std::endl; + if (possibles.empty()) + BOOST_THROW_EXCEPTION(createTypeError("Can't resolve identifier")); + else if (std::none_of(possibles.cbegin() + 1, possibles.cend(), + [&possibles](Declaration const* declaration) + { + return declaration->getScope() == possibles.front()->getScope(); + })) + setReferencedDeclaration(*possibles.front()); + else + BOOST_THROW_EXCEPTION(createTypeError("Can't resolve identifier")); + } + else + { + // named arguments + // TODO: don't support right now + // BOOST_THROW_EXCEPTION(createTypeError("Named arguments with overloaded functions are not supported yet.")); + } +} + void ElementaryTypeNameExpression::checkTypeRequirements() { m_type = make_shared<TypeType>(Type::fromElementaryTypeName(m_typeToken)); @@ -1134,8 +1134,8 @@ public: class Identifier: public PrimaryExpression { public: - Identifier(SourceLocation const& _location, ASTPointer<ASTString> const& _name, bool _isCallable): - PrimaryExpression(_location), m_name(_name), m_isCallable(_isCallable) {} + Identifier(SourceLocation const& _location, ASTPointer<ASTString> const& _name): + PrimaryExpression(_location), m_name(_name) {} virtual void accept(ASTVisitor& _visitor) override; virtual void accept(ASTConstVisitor& _visitor) const override; virtual void checkTypeRequirements() override; @@ -1151,9 +1151,15 @@ public: Declaration const* getReferencedDeclaration() const { return m_referencedDeclaration; } ContractDefinition const* getCurrentContract() const { return m_currentContract; } - bool isCallable() const { return m_isCallable; } + void setOverloadedDeclarations(std::set<Declaration const*> const& _declarations) { m_overloadedDeclarations = _declarations; } + std::set<Declaration const*> getOverloadedDeclarations() const { return m_overloadedDeclarations; } + void checkTypeRequirementsWithFunctionCall(FunctionCall const& _functionCall); + void checkTypeRequirementsFromVariableDeclaration(); + + void overloadResolution(FunctionCall const& _functionCall); private: + ASTPointer<ASTString> m_name; /// Declaration the name refers to. @@ -1161,7 +1167,8 @@ private: /// Stores a reference to the current contract. This is needed because types of base contracts /// change depending on the context. ContractDefinition const* m_currentContract = nullptr; - bool m_isCallable = false; + /// A set of overloaded declarations, right now only FunctionDefinition has overloaded declarations. + std::set<Declaration const*> m_overloadedDeclarations; }; /** diff --git a/CompilerContext.cpp b/CompilerContext.cpp index f787db7f..b12e0192 100644 --- a/CompilerContext.cpp +++ b/CompilerContext.cpp @@ -108,8 +108,8 @@ eth::AssemblyItem CompilerContext::getVirtualFunctionEntryLabel(FunctionDefiniti for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) { if (!function->isConstructor() && - dynamic_cast<FunctionType const&>(*function->getType()).getCanonicalSignature() == - dynamic_cast<FunctionType const&>(*_function.getType()).getCanonicalSignature()) + dynamic_cast<FunctionType const&>(*function->getType(contract)).getCanonicalSignature() == + dynamic_cast<FunctionType const&>(*_function.getType(contract)).getCanonicalSignature()) return getFunctionEntryLabel(*function); } solAssert(false, "Virtual function " + _function.getName() + " not found."); diff --git a/ExpressionCompiler.cpp b/ExpressionCompiler.cpp index 3d7a2531..5e5442ba 100644 --- a/ExpressionCompiler.cpp +++ b/ExpressionCompiler.cpp @@ -822,7 +822,11 @@ bool ExpressionCompiler::visit(IndexAccess const& _indexAccess) void ExpressionCompiler::endVisit(Identifier const& _identifier) { Declaration const* declaration = _identifier.getReferencedDeclaration(); - if (MagicVariableDeclaration const* magicVar = dynamic_cast<MagicVariableDeclaration const*>(declaration)) + if (declaration == nullptr) + { + // no-op + } + else if (MagicVariableDeclaration const* magicVar = dynamic_cast<MagicVariableDeclaration const*>(declaration)) { if (magicVar->getType()->getCategory() == Type::Category::Contract) // "this" or "super" diff --git a/NameAndTypeResolver.cpp b/NameAndTypeResolver.cpp index f6ee2f1d..c787ae6b 100644 --- a/NameAndTypeResolver.cpp +++ b/NameAndTypeResolver.cpp @@ -90,15 +90,15 @@ void NameAndTypeResolver::updateDeclaration(Declaration const& _declaration) solAssert(_declaration.getScope() == nullptr, "Updated declaration outside global scope."); } -Declaration const* NameAndTypeResolver::resolveName(ASTString const& _name, Declaration const* _scope) const +std::set<Declaration const*> NameAndTypeResolver::resolveName(ASTString const& _name, Declaration const* _scope) const { auto iterator = m_scopes.find(_scope); if (iterator == end(m_scopes)) - return nullptr; + return std::set<Declaration const*>({}); return iterator->second.resolveName(_name, false); } -Declaration const* NameAndTypeResolver::getNameFromCurrentScope(ASTString const& _name, bool _recursive) +std::set<Declaration const*> NameAndTypeResolver::getNameFromCurrentScope(ASTString const& _name, bool _recursive) { return m_currentScope->resolveName(_name, _recursive); } @@ -108,13 +108,11 @@ void NameAndTypeResolver::importInheritedScope(ContractDefinition const& _base) auto iterator = m_scopes.find(&_base); solAssert(iterator != end(m_scopes), ""); for (auto const& nameAndDeclaration: iterator->second.getDeclarations()) - { - Declaration const* declaration = nameAndDeclaration.second; - // Import if it was declared in the base, is not the constructor and is visible in derived classes - if (declaration->getScope() == &_base && declaration->getName() != _base.getName() && - declaration->isVisibleInDerivedContracts()) - m_currentScope->registerDeclaration(*declaration); - } + for (auto const& declaration: nameAndDeclaration.second) + // Import if it was declared in the base, is not the constructor and is visible in derived classes + if (declaration->getScope() == &_base && declaration->getName() != _base.getName() && + declaration->isVisibleInDerivedContracts()) + m_currentScope->registerDeclaration(*declaration); } void NameAndTypeResolver::linearizeBaseContracts(ContractDefinition& _contract) const @@ -361,24 +359,31 @@ bool ReferencesResolver::visit(Mapping&) bool ReferencesResolver::visit(UserDefinedTypeName& _typeName) { - Declaration const* declaration = m_resolver.getNameFromCurrentScope(_typeName.getName()); - if (!declaration) + auto declarations = m_resolver.getNameFromCurrentScope(_typeName.getName()); + if (declarations.empty()) BOOST_THROW_EXCEPTION(DeclarationError() << errinfo_sourceLocation(_typeName.getLocation()) << errinfo_comment("Undeclared identifier.")); - _typeName.setReferencedDeclaration(*declaration); + else if (declarations.size() > 1) + BOOST_THROW_EXCEPTION(DeclarationError() << errinfo_sourceLocation(_typeName.getLocation()) + << errinfo_comment("Duplicate identifier.")); + else + _typeName.setReferencedDeclaration(**declarations.begin()); return false; } bool ReferencesResolver::visit(Identifier& _identifier) { - Declaration const* declaration = m_resolver.getNameFromCurrentScope(_identifier.getName()); - if (!declaration) + auto declarations = m_resolver.getNameFromCurrentScope(_identifier.getName()); + if (declarations.empty()) BOOST_THROW_EXCEPTION(DeclarationError() << errinfo_sourceLocation(_identifier.getLocation()) << errinfo_comment("Undeclared identifier.")); - _identifier.setReferencedDeclaration(*declaration, m_currentContract); + else if (declarations.size() == 1) + _identifier.setReferencedDeclaration(**declarations.begin(), m_currentContract); + else + // Duplicate declaration will be checked in checkTypeRequirements() + _identifier.setOverloadedDeclarations(declarations); return false; } - } } diff --git a/NameAndTypeResolver.h b/NameAndTypeResolver.h index 63b8ab63..82877617 100644 --- a/NameAndTypeResolver.h +++ b/NameAndTypeResolver.h @@ -56,11 +56,11 @@ public: /// Resolves the given @a _name inside the scope @a _scope. If @a _scope is omitted, /// the global scope is used (i.e. the one containing only the contract). /// @returns a pointer to the declaration on success or nullptr on failure. - Declaration const* resolveName(ASTString const& _name, Declaration const* _scope = nullptr) const; + std::set<Declaration const*> resolveName(ASTString const& _name, Declaration const* _scope = nullptr) const; /// Resolves a name in the "current" scope. Should only be called during the initial /// resolving phase. - Declaration const* getNameFromCurrentScope(ASTString const& _name, bool _recursive = true); + std::set<Declaration const*> getNameFromCurrentScope(ASTString const& _name, bool _recursive = true); private: void reset(); @@ -837,14 +837,9 @@ ASTPointer<Expression> Parser::parsePrimaryExpression() expression = nodeFactory.createNode<Literal>(token, getLiteralAndAdvance()); break; case Token::Identifier: - { nodeFactory.markEndPosition(); - // if the next token is '(', this identifier looks like function call, - // it could be a contract, event etc. - bool isCallable = m_scanner->peekNextToken() == Token::LParen; - expression = nodeFactory.createNode<Identifier>(getLiteralAndAdvance(), isCallable); + expression = nodeFactory.createNode<Identifier>(getLiteralAndAdvance()); break; - } case Token::LParen: { m_scanner->next(); @@ -77,7 +77,7 @@ public: enum class Category { Integer, IntegerConstant, Bool, Real, Array, - String, Contract, Struct, Function, Enum, + String, Contract, Struct, Function, OverloadedFunctions, Enum, Mapping, Void, TypeType, Modifier, Magic }; @@ -524,6 +524,19 @@ private: Declaration const* m_declaration = nullptr; }; +class OverloadedFunctionType: public Type +{ +public: + explicit OverloadedFunctionType(std::set<Declaration const*> const& _overloadedDeclarations, Identifier* _identifier): + m_overloadedDeclarations(_overloadedDeclarations), m_identifier(_identifier) {} + virtual Category getCategory() const override { return Category::OverloadedFunctions; } + virtual std::string toString() const override { return "OverloadedFunctions"; } + +// private: + std::set<Declaration const*> m_overloadedDeclarations; + Identifier * m_identifier; +}; + /** * The type of a mapping, there is one distinct type per key/value type pair. */ |