diff options
Diffstat (limited to 'libsolidity/formal/SolverInterface.h')
-rw-r--r-- | libsolidity/formal/SolverInterface.h | 22 |
1 files changed, 16 insertions, 6 deletions
diff --git a/libsolidity/formal/SolverInterface.h b/libsolidity/formal/SolverInterface.h index 7f20876e..4a4b3fb1 100644 --- a/libsolidity/formal/SolverInterface.h +++ b/libsolidity/formal/SolverInterface.h @@ -55,7 +55,7 @@ struct Sort Sort(Kind _kind): kind(_kind) {} virtual ~Sort() = default; - bool operator==(Sort const& _other) const { return kind == _other.kind; } + virtual bool operator==(Sort const& _other) const { return kind == _other.kind; } Kind const kind; }; @@ -65,16 +65,22 @@ struct FunctionSort: public Sort { FunctionSort(std::vector<SortPointer> _domain, SortPointer _codomain): Sort(Kind::Function), domain(std::move(_domain)), codomain(std::move(_codomain)) {} - bool operator==(FunctionSort const& _other) const + bool operator==(Sort const& _other) const override { + if (!Sort::operator==(_other)) + return false; + auto _otherFunction = dynamic_cast<FunctionSort const*>(&_other); + solAssert(_otherFunction, ""); + if (domain.size() != _otherFunction->domain.size()) + return false; if (!std::equal( domain.begin(), domain.end(), - _other.domain.begin(), + _otherFunction->domain.begin(), [&](SortPointer _a, SortPointer _b) { return *_a == *_b; } )) return false; - return Sort::operator==(_other) && *codomain == *_other.codomain; + return *codomain == *_otherFunction->codomain; } std::vector<SortPointer> domain; @@ -87,9 +93,13 @@ struct ArraySort: public Sort /// _range is the sort of the values ArraySort(SortPointer _domain, SortPointer _range): Sort(Kind::Array), domain(std::move(_domain)), range(std::move(_range)) {} - bool operator==(ArraySort const& _other) const + bool operator==(Sort const& _other) const override { - return Sort::operator==(_other) && *domain == *_other.domain && *range == *_other.range; + if (!Sort::operator==(_other)) + return false; + auto _otherArray = dynamic_cast<ArraySort const*>(&_other); + solAssert(_otherArray, ""); + return *domain == *_otherArray->domain && *range == *_otherArray->range; } SortPointer domain; |