c9455910aa41106e4951eebe5e162f16e692d07d
[lldb.git] / flang / include / flang / Evaluate / traverse.h
1 //===-- include/flang/Evaluate/traverse.h -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef FORTRAN_EVALUATE_TRAVERSE_H_
10 #define FORTRAN_EVALUATE_TRAVERSE_H_
11
12 // A utility for scanning all of the constituent objects in an Expr<>
13 // expression representation using a collection of mutually recursive
14 // functions to compose a function object.
15 //
16 // The class template Traverse<> below implements a function object that
17 // can handle every type that can appear in or around an Expr<>.
18 // Each of its overloads for operator() should be viewed as a *default*
19 // handler; some of these must be overridden by the client to accomplish
20 // its particular task.
21 //
22 // The client (Visitor) of Traverse<Visitor,Result> must define:
23 // - a member function "Result Default();"
24 // - a member function "Result Combine(Result &&, Result &&)"
25 // - overrides for "Result operator()"
26 //
27 // Boilerplate classes also appear below to ease construction of visitors.
28 // See CheckSpecificationExpr() in check-expression.cpp for an example client.
29 //
30 // How this works:
31 // - The operator() overloads in Traverse<> invoke the visitor's Default() for
32 //   expression leaf nodes.  They invoke the visitor's operator() for the
33 //   subtrees of interior nodes, and the visitor's Combine() to merge their
34 //   results together.
35 // - Overloads of operator() in each visitor handle the cases of interest.
36
37 #include "expression.h"
38 #include "flang/Semantics/symbol.h"
39 #include "flang/Semantics/type.h"
40 #include <set>
41 #include <type_traits>
42
43 namespace Fortran::evaluate {
44 template <typename Visitor, typename Result> class Traverse {
45 public:
46   explicit Traverse(Visitor &v) : visitor_{v} {}
47
48   // Packaging
49   template <typename A, bool C>
50   Result operator()(const common::Indirection<A, C> &x) const {
51     return visitor_(x.value());
52   }
53   template <typename A> Result operator()(const SymbolRef x) const {
54     return visitor_(*x);
55   }
56   template <typename A> Result operator()(const std::unique_ptr<A> &x) const {
57     return visitor_(x.get());
58   }
59   template <typename A> Result operator()(const std::shared_ptr<A> &x) const {
60     return visitor_(x.get());
61   }
62   template <typename A> Result operator()(const A *x) const {
63     if (x) {
64       return visitor_(*x);
65     } else {
66       return visitor_.Default();
67     }
68   }
69   template <typename A> Result operator()(const std::optional<A> &x) const {
70     if (x) {
71       return visitor_(*x);
72     } else {
73       return visitor_.Default();
74     }
75   }
76   template <typename... A>
77   Result operator()(const std::variant<A...> &u) const {
78     return std::visit(visitor_, u);
79   }
80   template <typename A> Result operator()(const std::vector<A> &x) const {
81     return CombineContents(x);
82   }
83
84   // Leaves
85   Result operator()(const BOZLiteralConstant &) const {
86     return visitor_.Default();
87   }
88   Result operator()(const NullPointer &) const { return visitor_.Default(); }
89   template <typename T> Result operator()(const Constant<T> &x) const {
90     if constexpr (T::category == TypeCategory::Derived) {
91       std::optional<Result> result;
92       for (const StructureConstructorValues &map : x.values()) {
93         for (const auto &pair : map) {
94           auto value{visitor_(pair.second.value())};
95           result = result
96               ? visitor_.Combine(std::move(*result), std::move(value))
97               : std::move(value);
98         }
99       }
100       return result ? *result : visitor_.Default();
101     } else {
102       return visitor_.Default();
103     }
104   }
105   Result operator()(const Symbol &) const { return visitor_.Default(); }
106   Result operator()(const StaticDataObject &) const {
107     return visitor_.Default();
108   }
109   Result operator()(const ImpliedDoIndex &) const { return visitor_.Default(); }
110
111   // Variables
112   Result operator()(const BaseObject &x) const { return visitor_(x.u); }
113   Result operator()(const Component &x) const {
114     return Combine(x.base(), x.GetLastSymbol());
115   }
116   Result operator()(const NamedEntity &x) const {
117     if (const Component * component{x.UnwrapComponent()}) {
118       return visitor_(*component);
119     } else {
120       return visitor_(x.GetFirstSymbol());
121     }
122   }
123   Result operator()(const TypeParamInquiry &x) const {
124     return visitor_(x.base());
125   }
126   Result operator()(const Triplet &x) const {
127     return Combine(x.lower(), x.upper(), x.stride());
128   }
129   Result operator()(const Subscript &x) const { return visitor_(x.u); }
130   Result operator()(const ArrayRef &x) const {
131     return Combine(x.base(), x.subscript());
132   }
133   Result operator()(const CoarrayRef &x) const {
134     return Combine(
135         x.base(), x.subscript(), x.cosubscript(), x.stat(), x.team());
136   }
137   Result operator()(const DataRef &x) const { return visitor_(x.u); }
138   Result operator()(const Substring &x) const {
139     return Combine(x.parent(), x.lower(), x.upper());
140   }
141   Result operator()(const ComplexPart &x) const {
142     return visitor_(x.complex());
143   }
144   template <typename T> Result operator()(const Designator<T> &x) const {
145     return visitor_(x.u);
146   }
147   template <typename T> Result operator()(const Variable<T> &x) const {
148     return visitor_(x.u);
149   }
150   Result operator()(const DescriptorInquiry &x) const {
151     return visitor_(x.base());
152   }
153
154   // Calls
155   Result operator()(const SpecificIntrinsic &) const {
156     return visitor_.Default();
157   }
158   Result operator()(const ProcedureDesignator &x) const {
159     if (const Component * component{x.GetComponent()}) {
160       return visitor_(*component);
161     } else if (const Symbol * symbol{x.GetSymbol()}) {
162       return visitor_(*symbol);
163     } else {
164       return visitor_(DEREF(x.GetSpecificIntrinsic()));
165     }
166   }
167   Result operator()(const ActualArgument &x) const {
168     if (const auto *symbol{x.GetAssumedTypeDummy()}) {
169       return visitor_(*symbol);
170     } else {
171       return visitor_(x.UnwrapExpr());
172     }
173   }
174   Result operator()(const ProcedureRef &x) const {
175     return Combine(x.proc(), x.arguments());
176   }
177   template <typename T> Result operator()(const FunctionRef<T> &x) const {
178     return visitor_(static_cast<const ProcedureRef &>(x));
179   }
180
181   // Other primaries
182   template <typename T>
183   Result operator()(const ArrayConstructorValue<T> &x) const {
184     return visitor_(x.u);
185   }
186   template <typename T>
187   Result operator()(const ArrayConstructorValues<T> &x) const {
188     return CombineContents(x);
189   }
190   template <typename T> Result operator()(const ImpliedDo<T> &x) const {
191     return Combine(x.lower(), x.upper(), x.stride(), x.values());
192   }
193   Result operator()(const semantics::ParamValue &x) const {
194     return visitor_(x.GetExplicit());
195   }
196   Result operator()(
197       const semantics::DerivedTypeSpec::ParameterMapType::value_type &x) const {
198     return visitor_(x.second);
199   }
200   Result operator()(const semantics::DerivedTypeSpec &x) const {
201     return CombineContents(x.parameters());
202   }
203   Result operator()(const StructureConstructorValues::value_type &x) const {
204     return visitor_(x.second);
205   }
206   Result operator()(const StructureConstructor &x) const {
207     return visitor_.Combine(visitor_(x.derivedTypeSpec()), CombineContents(x));
208   }
209
210   // Operations and wrappers
211   template <typename D, typename R, typename O>
212   Result operator()(const Operation<D, R, O> &op) const {
213     return visitor_(op.left());
214   }
215   template <typename D, typename R, typename LO, typename RO>
216   Result operator()(const Operation<D, R, LO, RO> &op) const {
217     return Combine(op.left(), op.right());
218   }
219   Result operator()(const Relational<SomeType> &x) const {
220     return visitor_(x.u);
221   }
222   template <typename T> Result operator()(const Expr<T> &x) const {
223     return visitor_(x.u);
224   }
225
226 private:
227   template <typename ITER> Result CombineRange(ITER iter, ITER end) const {
228     if (iter == end) {
229       return visitor_.Default();
230     } else {
231       Result result{visitor_(*iter++)};
232       for (; iter != end; ++iter) {
233         result = visitor_.Combine(std::move(result), visitor_(*iter));
234       }
235       return result;
236     }
237   }
238
239   template <typename A> Result CombineContents(const A &x) const {
240     return CombineRange(x.begin(), x.end());
241   }
242
243   template <typename A, typename... Bs>
244   Result Combine(const A &x, const Bs &...ys) const {
245     if constexpr (sizeof...(Bs) == 0) {
246       return visitor_(x);
247     } else {
248       return visitor_.Combine(visitor_(x), Combine(ys...));
249     }
250   }
251
252   Visitor &visitor_;
253 };
254
255 // For validity checks across an expression: if any operator() result is
256 // false, so is the overall result.
257 template <typename Visitor, bool DefaultValue,
258     typename Base = Traverse<Visitor, bool>>
259 struct AllTraverse : public Base {
260   explicit AllTraverse(Visitor &v) : Base{v} {}
261   using Base::operator();
262   static bool Default() { return DefaultValue; }
263   static bool Combine(bool x, bool y) { return x && y; }
264 };
265
266 // For searches over an expression: the first operator() result that
267 // is truthful is the final result.  Works for Booleans, pointers,
268 // and std::optional<>.
269 template <typename Visitor, typename Result = bool,
270     typename Base = Traverse<Visitor, Result>>
271 class AnyTraverse : public Base {
272 public:
273   explicit AnyTraverse(Visitor &v) : Base{v} {}
274   using Base::operator();
275   Result Default() const { return default_; }
276   static Result Combine(Result &&x, Result &&y) {
277     if (x) {
278       return std::move(x);
279     } else {
280       return std::move(y);
281     }
282   }
283
284 private:
285   Result default_{};
286 };
287
288 template <typename Visitor, typename Set,
289     typename Base = Traverse<Visitor, Set>>
290 struct SetTraverse : public Base {
291   explicit SetTraverse(Visitor &v) : Base{v} {}
292   using Base::operator();
293   static Set Default() { return {}; }
294   static Set Combine(Set &&x, Set &&y) {
295 #if defined __GNUC__ && !defined __APPLE__ && !(CLANG_LIBRARIES)
296     x.merge(y);
297 #else
298     // std::set::merge() not available (yet)
299     for (auto &value : y) {
300       x.insert(std::move(value));
301     }
302 #endif
303     return std::move(x);
304   }
305 };
306
307 } // namespace Fortran::evaluate
308 #endif