diff options
-rw-r--r-- | Changelog.md | 1 | ||||
-rw-r--r-- | libsolidity/codegen/ContractCompiler.cpp | 51 | ||||
-rw-r--r-- | test/libsolidity/Assembly.cpp | 10 | ||||
-rw-r--r-- | test/libsolidity/GasCosts.cpp | 43 |
4 files changed, 90 insertions, 15 deletions
diff --git a/Changelog.md b/Changelog.md index e26ab0f9..63c0280a 100644 --- a/Changelog.md +++ b/Changelog.md @@ -6,6 +6,7 @@ Language Features: Compiler Features: * Inline Assembly: Improve error messages around invalid function argument count. + * Code Generator: Only check callvalue once if all functions are non-payable. * Code Generator: Use codecopy for string constants more aggressively. * Code Generator: Use binary search for dispatch function if more efficient. The size/speed tradeoff can be tuned using ``--optimize-runs``. * Compiler Interface: Disallow unknown keys in standard JSON input. diff --git a/libsolidity/codegen/ContractCompiler.cpp b/libsolidity/codegen/ContractCompiler.cpp index 16c90b60..b051d260 100644 --- a/libsolidity/codegen/ContractCompiler.cpp +++ b/libsolidity/codegen/ContractCompiler.cpp @@ -331,6 +331,25 @@ void ContractCompiler::appendInternalSelector( } } +namespace +{ + +// Helper function to check if any function is payable +bool hasPayableFunctions(ContractDefinition const& _contract) +{ + FunctionDefinition const* fallback = _contract.fallbackFunction(); + if (fallback && fallback->isPayable()) + return true; + + for (auto const& it: _contract.interfaceFunctions()) + if (it.second->isPayable()) + return true; + + return false; +} + +} + void ContractCompiler::appendFunctionSelector(ContractDefinition const& _contract) { map<FixedHash<4>, FunctionTypePointer> interfaceFunctions = _contract.interfaceFunctions(); @@ -342,6 +361,15 @@ void ContractCompiler::appendFunctionSelector(ContractDefinition const& _contrac } FunctionDefinition const* fallback = _contract.fallbackFunction(); + solAssert(!_contract.isLibrary() || !fallback, "Libraries can't have fallback functions"); + + bool needToAddCallvalueCheck = true; + if (!hasPayableFunctions(_contract) && !interfaceFunctions.empty() && !_contract.isLibrary()) + { + appendCallValueCheck(); + needToAddCallvalueCheck = false; + } + eth::AssemblyItem notFound = m_context.newTag(); // directly jump to fallback if the data is too short to contain a function selector // also guards against short data @@ -350,23 +378,26 @@ void ContractCompiler::appendFunctionSelector(ContractDefinition const& _contrac // retrieve the function signature hash from the calldata if (!interfaceFunctions.empty()) + { CompilerUtils(m_context).loadFromMemory(0, IntegerType(CompilerUtils::dataStartOffset * 8), true); - // stack now is: <can-call-non-view-functions>? <funhash> - vector<FixedHash<4>> sortedIDs; - for (auto const& it: interfaceFunctions) - { - callDataUnpackerEntryPoints.insert(std::make_pair(it.first, m_context.newTag())); - sortedIDs.emplace_back(it.first); + // stack now is: <can-call-non-view-functions>? <funhash> + vector<FixedHash<4>> sortedIDs; + for (auto const& it: interfaceFunctions) + { + callDataUnpackerEntryPoints.emplace(it.first, m_context.newTag()); + sortedIDs.emplace_back(it.first); + } + std::sort(sortedIDs.begin(), sortedIDs.end()); + appendInternalSelector(callDataUnpackerEntryPoints, sortedIDs, notFound, m_optimise_runs); } - std::sort(sortedIDs.begin(), sortedIDs.end()); - appendInternalSelector(callDataUnpackerEntryPoints, sortedIDs, notFound, m_optimise_runs); m_context << notFound; + if (fallback) { solAssert(!_contract.isLibrary(), ""); - if (!fallback->isPayable()) + if (!fallback->isPayable() && needToAddCallvalueCheck) appendCallValueCheck(); solAssert(fallback->isFallback(), ""); @@ -396,7 +427,7 @@ void ContractCompiler::appendFunctionSelector(ContractDefinition const& _contrac m_context.setStackOffset(0); // We have to allow this for libraries, because value of the previous // call is still visible in the delegatecall. - if (!functionType->isPayable() && !_contract.isLibrary()) + if (!functionType->isPayable() && !_contract.isLibrary() && needToAddCallvalueCheck) appendCallValueCheck(); // Return tag is used to jump out of the function. diff --git a/test/libsolidity/Assembly.cpp b/test/libsolidity/Assembly.cpp index aa10147c..5d8c89a4 100644 --- a/test/libsolidity/Assembly.cpp +++ b/test/libsolidity/Assembly.cpp @@ -165,14 +165,14 @@ BOOST_AUTO_TEST_CASE(location_test) auto codegenCharStream = make_shared<CharStream>("", "--CODEGEN--"); vector<SourceLocation> locations = - vector<SourceLocation>(hasShifts ? 21 : 22, SourceLocation(2, 82, sourceCode)) + - vector<SourceLocation>(2, SourceLocation(20, 79, sourceCode)) + - vector<SourceLocation>(1, SourceLocation(8, 17, codegenCharStream)) + - vector<SourceLocation>(3, SourceLocation(5, 7, codegenCharStream)) + + vector<SourceLocation>(4, SourceLocation(2, 82, sourceCode)) + + vector<SourceLocation>(1, SourceLocation(8, 17, codegenCharStream)) + + vector<SourceLocation>(3, SourceLocation(5, 7, codegenCharStream)) + vector<SourceLocation>(1, SourceLocation(30, 31, codegenCharStream)) + vector<SourceLocation>(1, SourceLocation(27, 28, codegenCharStream)) + vector<SourceLocation>(1, SourceLocation(20, 32, codegenCharStream)) + - vector<SourceLocation>(1, SourceLocation(5, 7, codegenCharStream)) + + vector<SourceLocation>(1, SourceLocation(5, 7, codegenCharStream)) + + vector<SourceLocation>(hasShifts ? 19 : 20, SourceLocation(2, 82, sourceCode)) + vector<SourceLocation>(24, SourceLocation(20, 79, sourceCode)) + vector<SourceLocation>(1, SourceLocation(49, 58, sourceCode)) + vector<SourceLocation>(1, SourceLocation(72, 74, sourceCode)) + diff --git a/test/libsolidity/GasCosts.cpp b/test/libsolidity/GasCosts.cpp index 15658a91..c7da3ca0 100644 --- a/test/libsolidity/GasCosts.cpp +++ b/test/libsolidity/GasCosts.cpp @@ -82,6 +82,49 @@ BOOST_AUTO_TEST_CASE(string_storage) } } +BOOST_AUTO_TEST_CASE(single_callvaluecheck) +{ + string sourceCode = R"( + // All functions nonpayable, we can check callvalue at the beginning + contract Nonpayable { + address a; + function f(address b) public { + a = b; + } + function f1(address b) public pure returns (uint c) { + return uint(b) + 2; + } + function f2(address b) public pure returns (uint) { + return uint(b) + 8; + } + function f3(address, uint c) pure public returns (uint) { + return c - 5; + } + } + // At least on payable function, we cannot do the optimization. + contract Payable { + address a; + function f(address b) public { + a = b; + } + function f1(address b) public pure returns (uint c) { + return uint(b) + 2; + } + function f2(address b) public pure returns (uint) { + return uint(b) + 8; + } + function f3(address, uint c) payable public returns (uint) { + return c - 5; + } + } + )"; + compileAndRun(sourceCode); + size_t bytecodeSizeNonpayable = m_compiler.object("Nonpayable").bytecode.size(); + size_t bytecodeSizePayable = m_compiler.object("Payable").bytecode.size(); + + BOOST_CHECK_EQUAL(bytecodeSizePayable - bytecodeSizeNonpayable, 26); +} + BOOST_AUTO_TEST_SUITE_END() } |