OASIS
Open Algebra Software
Loading...
Searching...
No Matches
BoundedBinaryExpression.hpp
Go to the documentation of this file.
1//
2// Created by Matthew McCall on 5/2/24.
3//
4
5#ifndef OASIS_BOUNDEDBINARYEXPRESSION_HPP
6#define OASIS_BOUNDEDBINARYEXPRESSION_HPP
7
9#include "Expression.hpp"
10#include "RecursiveCast.hpp"
11#include "Visit.hpp"
12
13namespace Oasis {
24template <template <IExpression, IExpression, IExpression, IExpression> class DerivedT, IExpression MostSigOpT = Expression, IExpression LeastSigOpT = MostSigOpT, IExpression LowerBoundT = Expression, IExpression UpperBoundT = LowerBoundT>
25class BoundedBinaryExpression : public BoundedExpression<LowerBoundT, UpperBoundT> {
26
27 using DerivedSpecialized = DerivedT<MostSigOpT, LeastSigOpT, LowerBoundT, UpperBoundT>;
28 using DerivedGeneralized = DerivedT<Expression, Expression, Expression, Expression>;
29
30public:
33 {
34 if (other.HasMostSigOp()) {
36 }
37
38 if (other.HasLeastSigOp()) {
40 }
41 }
42
43 BoundedBinaryExpression(const MostSigOpT& mostSigOp, const LeastSigOpT& leastSigOp)
44 {
45 SetMostSigOp(mostSigOp);
46 SetLeastSigOp(leastSigOp);
47 }
48
49 [[nodiscard]] auto Copy() const -> std::unique_ptr<Expression> final
50 {
51 return std::make_unique<DerivedSpecialized>(*static_cast<const DerivedSpecialized*>(this));
52 }
53
54 auto Copy(tf::Subflow& subflow) const -> std::unique_ptr<Expression> final
55 {
56 DerivedSpecialized copy;
57
58 if (this->mostSigOp) {
59 subflow.emplace([this, &copy](tf::Subflow& sbf) {
60 copy.SetMostSigOp(mostSigOp->Copy(sbf), sbf);
61 });
62 }
63
64 if (this->leastSigOp) {
65 subflow.emplace([this, &copy](tf::Subflow& sbf) {
66 copy.SetLeastSigOp(leastSigOp->Copy(sbf), sbf);
67 });
68 }
69
70 subflow.join();
71
73 }
74 [[nodiscard]] auto Differentiate(const Expression& differentiationVariable) const -> std::unique_ptr<Expression> override
75 {
76 return Generalize()->Differentiate(differentiationVariable);
77 }
78 [[nodiscard]] auto Equals(const Expression& other) const -> bool final
79 {
80 if (this->GetType() != other.GetType()) {
81 return false;
82 }
83
84 const auto otherGeneralized = other.Generalize();
85 const auto& otherBinaryGeneralized = static_cast<const DerivedGeneralized&>(*otherGeneralized);
86
87 bool mostSigOpMismatch = false, leastSigOpMismatch = false;
88
89 if (this->HasMostSigOp() == otherBinaryGeneralized.HasMostSigOp()) {
90 if (mostSigOp && otherBinaryGeneralized.HasMostSigOp()) {
91 mostSigOpMismatch = !mostSigOp->Equals(otherBinaryGeneralized.GetMostSigOp());
92 }
93 } else {
94 mostSigOpMismatch = true;
95 }
96
97 if (this->HasLeastSigOp() == otherBinaryGeneralized.HasLeastSigOp()) {
98 if (this->HasLeastSigOp() && otherBinaryGeneralized.HasLeastSigOp()) {
99 leastSigOpMismatch = !leastSigOp->Equals(otherBinaryGeneralized.GetLeastSigOp());
100 }
101 } else {
102 mostSigOpMismatch = true;
103 }
104
105 if (!mostSigOpMismatch && !leastSigOpMismatch) {
106 return true;
107 }
108
109 if (!(this->GetCategory() & Associative)) {
110 return false;
111 }
112
113 auto thisFlattened = std::vector<std::unique_ptr<Expression>> {};
114 auto otherFlattened = std::vector<std::unique_ptr<Expression>> {};
115
116 this->Flatten(thisFlattened);
117 otherBinaryGeneralized.Flatten(otherFlattened);
118
119 for (const auto& thisOperand : thisFlattened) {
120 if (std::find_if(otherFlattened.begin(), otherFlattened.end(), [&thisOperand](const auto& otherOperand) {
121 return thisOperand->Equals(*otherOperand);
122 })
123 == otherFlattened.end()) {
124 return false;
125 }
126 }
127
128 return true;
129 }
130
131 [[nodiscard]] auto Generalize() const -> std::unique_ptr<Expression> final
132 {
133 DerivedGeneralized generalized;
134
135 if (this->mostSigOp) {
136 generalized.SetMostSigOp(*this->mostSigOp->Copy());
137 }
138
139 if (this->leastSigOp) {
140 generalized.SetLeastSigOp(*this->leastSigOp->Copy());
141 }
142
143 return std::make_unique<DerivedGeneralized>(generalized);
144 }
145
146 auto Generalize(tf::Subflow& subflow) const -> std::unique_ptr<Expression> final
147 {
148 DerivedGeneralized generalized;
149
150 if (this->mostSigOp) {
151 subflow.emplace([this, &generalized](tf::Subflow& sbf) {
152 generalized.SetMostSigOp(*this->mostSigOp->Copy(sbf));
153 });
154 }
155
156 if (this->leastSigOp) {
157 subflow.emplace([this, &generalized](tf::Subflow& sbf) {
158 generalized.SetLeastSigOp(*this->leastSigOp->Copy(sbf));
159 });
160 }
161
162 subflow.join();
163
164 return std::make_unique<DerivedGeneralized>(generalized);
165 }
166
167 [[nodiscard]] auto Simplify() const -> std::unique_ptr<Expression> override
168 {
169 return Generalize()->Simplify();
170 }
171
172 auto Simplify(tf::Subflow& subflow) const -> std::unique_ptr<Expression> override
173 {
174 std::unique_ptr<Expression> generalized, simplified;
175
176 tf::Task generalizeTask = subflow.emplace([this, &generalized](tf::Subflow& sbf) {
177 generalized = Generalize(sbf);
178 });
179
180 tf::Task simplifyTask = subflow.emplace([&generalized, &simplified](tf::Subflow& sbf) {
181 simplified = generalized->Simplify(sbf);
182 });
183
184 simplifyTask.succeed(generalizeTask);
185 subflow.join();
186
187 return simplified;
188 }
189
190 [[nodiscard]] auto StructurallyEquivalent(const Expression& other) const -> bool final
191 {
192 if (this->GetType() != other.GetType()) {
193 return false;
194 }
195
196 const std::unique_ptr<Expression> otherGeneralized = other.Generalize();
197 const auto& otherBinaryGeneralized = static_cast<const DerivedGeneralized&>(*otherGeneralized);
198
199 if (this->HasMostSigOp() == otherBinaryGeneralized.HasMostSigOp()) {
200 if (this->HasMostSigOp() && otherBinaryGeneralized.HasMostSigOp()) {
201 if (!mostSigOp->StructurallyEquivalent(otherBinaryGeneralized.GetMostSigOp())) {
202 return false;
203 }
204 }
205 }
206
207 if (this->HasLeastSigOp() == otherBinaryGeneralized.HasLeastSigOp()) {
208 if (this->HasLeastSigOp() && otherBinaryGeneralized.HasLeastSigOp()) {
209 if (!leastSigOp->StructurallyEquivalent(otherBinaryGeneralized.GetLeastSigOp())) {
210 return false;
211 }
212 }
213 }
214
215 return true;
216 }
217
218 auto StructurallyEquivalent(const Expression& other, tf::Subflow& subflow) const -> bool final
219 {
220 if (this->GetType() != other.GetType()) {
221 return false;
222 }
223
224 std::unique_ptr<Expression> otherGeneralized;
225
226 tf::Task generalizeTask = subflow.emplace([&](tf::Subflow& sbf) {
227 otherGeneralized = other.Generalize(sbf);
228 });
229
230 bool mostSigOpEquivalent = false, leastSigOpEquivalent = false;
231
232 if (this->mostSigOp) {
233 tf::Task compMostSigOp = subflow.emplace([this, &otherGeneralized, &mostSigOpEquivalent](tf::Subflow& sbf) {
234 if (const auto& otherBinary = static_cast<const DerivedGeneralized&>(*otherGeneralized); otherBinary.HasMostSigOp()) {
235 mostSigOpEquivalent = mostSigOp->StructurallyEquivalent(otherBinary.GetMostSigOp(), sbf);
236 }
237 });
238
239 compMostSigOp.succeed(generalizeTask);
240 }
241
242 if (this->leastSigOp) {
243 tf::Task compLeastSigOp = subflow.emplace([this, &otherGeneralized, &leastSigOpEquivalent](tf::Subflow& sbf) {
244 if (const auto& otherBinary = static_cast<const DerivedGeneralized&>(*otherGeneralized); otherBinary.HasLeastSigOp()) {
245 leastSigOpEquivalent = leastSigOp->StructurallyEquivalent(otherBinary.GetLeastSigOp(), sbf);
246 }
247 });
248
249 compLeastSigOp.succeed(generalizeTask);
250 }
251
252 subflow.join();
253
254 return mostSigOpEquivalent && leastSigOpEquivalent;
255 }
256
269 {
270 if (this->mostSigOp->template Is<DerivedT>()) {
271 auto generalizedMostSigOp = this->mostSigOp->Generalize();
272 const auto& mostSigOp = static_cast<const DerivedGeneralized&>(*generalizedMostSigOp);
273 mostSigOp.Flatten(out);
274 } else {
275 out.push_back(this->mostSigOp->Copy());
276 }
277
278 if (this->leastSigOp->template Is<DerivedT>()) {
279 auto generalizedLeastSigOp = this->leastSigOp->Generalize();
280 const auto& leastSigOp = static_cast<const DerivedGeneralized&>(*generalizedLeastSigOp);
281 leastSigOp.Flatten(out);
282 } else {
283 out.push_back(this->leastSigOp->Copy());
284 }
285 }
286
291 auto GetMostSigOp() const -> const MostSigOpT&
292 {
293 assert(mostSigOp != nullptr);
294 return *mostSigOp;
295 }
296
301 auto GetLeastSigOp() const -> const LeastSigOpT&
302 {
303 assert(leastSigOp != nullptr);
304 return *leastSigOp;
305 }
306
311 [[nodiscard]] auto HasMostSigOp() const -> bool
312 {
313 return mostSigOp != nullptr;
314 }
315
320 [[nodiscard]] auto HasLeastSigOp() const -> bool
321 {
322 return leastSigOp != nullptr;
323 }
324
329 template <typename T>
331 auto SetMostSigOp(const T& op) -> bool
332 {
334 this->mostSigOp = op.Copy();
335 return true;
336 }
337
339 this->mostSigOp = std::make_unique<MostSigOpT>(op);
340 return true;
341 }
342
343 if (auto castedOp = Oasis::RecursiveCast<MostSigOpT>(op); castedOp) {
344 mostSigOp = std::move(castedOp);
345 return true;
346 }
347
348 return false;
349 }
350
355 template <typename T>
357 auto SetLeastSigOp(const T& op) -> bool
358 {
360 this->leastSigOp = op.Copy();
361 return true;
362 }
363
365 this->leastSigOp = std::make_unique<LeastSigOpT>(op);
366 return true;
367 }
368
369 if (auto castedOp = Oasis::RecursiveCast<LeastSigOpT>(op); castedOp) {
370 leastSigOp = std::move(castedOp);
371 return true;
372 }
373
374 return false;
375 }
376
377 auto Substitute(const Expression& var, const Expression& val) -> std::unique_ptr<Expression> override
378 {
379 const std::unique_ptr<Expression> left = GetMostSigOp().Substitute(var, val);
380 const std::unique_ptr<Expression> right = GetLeastSigOp().Substitute(var, val);
381 DerivedGeneralized comb { *left, *right };
382 auto ret = comb.Simplify();
383 return ret;
384 }
389 auto SwapOperands() const -> DerivedSpecialized
390 {
391 return DerivedT { *this->leastSigOp, *this->mostSigOp };
392 }
393
395
396 void Serialize(SerializationVisitor& visitor) const override
397 {
398 const auto generalized = Generalize();
399 const auto& derivedGeneralized = dynamic_cast<const DerivedGeneralized&>(*generalized);
400 visitor.Serialize(derivedGeneralized);
401 }
402
403private:
406};
407
408}
409
410#endif // OASIS_BOUNDEDBINARYEXPRESSION_HPP
A base class template for binary expressions that carry lower and upper bounds.
Definition BoundedBinaryExpression.hpp:25
auto Differentiate(const Expression &differentiationVariable) const -> std::unique_ptr< Expression > override
Tries to differentiate this function.
Definition BoundedBinaryExpression.hpp:74
auto SwapOperands() const -> DerivedSpecialized
Swaps the operands of this expression.
Definition BoundedBinaryExpression.hpp:389
auto GetMostSigOp() const -> const MostSigOpT &
Gets the most significant operand of this expression.
Definition BoundedBinaryExpression.hpp:291
auto Flatten(std::vector< std::unique_ptr< Expression > > &out) const -> void
Flattens this expression.
Definition BoundedBinaryExpression.hpp:268
auto Simplify(tf::Subflow &subflow) const -> std::unique_ptr< Expression > override
Definition BoundedBinaryExpression.hpp:172
BoundedBinaryExpression(const MostSigOpT &mostSigOp, const LeastSigOpT &leastSigOp)
Definition BoundedBinaryExpression.hpp:43
auto Simplify() const -> std::unique_ptr< Expression > override
Simplifies this expression.
Definition BoundedBinaryExpression.hpp:167
auto Substitute(const Expression &var, const Expression &val) -> std::unique_ptr< Expression > override
Definition BoundedBinaryExpression.hpp:377
auto Equals(const Expression &other) const -> bool final
Compares this expression to another expression for equality.
Definition BoundedBinaryExpression.hpp:78
auto SetMostSigOp(const T &op) -> bool
Sets the most significant operand of this expression.
Definition BoundedBinaryExpression.hpp:331
auto Generalize() const -> std::unique_ptr< Expression > final
Converts this expression to a more general expression.
Definition BoundedBinaryExpression.hpp:131
auto HasMostSigOp() const -> bool
Gets whether this expression has a most significant operand.
Definition BoundedBinaryExpression.hpp:311
void Serialize(SerializationVisitor &visitor) const override
Definition BoundedBinaryExpression.hpp:396
auto GetLeastSigOp() const -> const LeastSigOpT &
Gets the least significant operand of this expression.
Definition BoundedBinaryExpression.hpp:301
auto HasLeastSigOp() const -> bool
Gets whether this expression has a least significant operand.
Definition BoundedBinaryExpression.hpp:320
auto SetLeastSigOp(const T &op) -> bool
Sets the least significant operand of this expression.
Definition BoundedBinaryExpression.hpp:357
auto Generalize(tf::Subflow &subflow) const -> std::unique_ptr< Expression > final
Definition BoundedBinaryExpression.hpp:146
auto Copy(tf::Subflow &subflow) const -> std::unique_ptr< Expression > final
Definition BoundedBinaryExpression.hpp:54
auto Copy() const -> std::unique_ptr< Expression > final
Copies this expression.
Definition BoundedBinaryExpression.hpp:49
auto StructurallyEquivalent(const Expression &other) const -> bool final
Checks whether this expression is structurally equivalent to another expression.
Definition BoundedBinaryExpression.hpp:190
auto operator=(const BoundedBinaryExpression &other) -> BoundedBinaryExpression &=default
auto StructurallyEquivalent(const Expression &other, tf::Subflow &subflow) const -> bool final
Definition BoundedBinaryExpression.hpp:218
BoundedBinaryExpression(const BoundedBinaryExpression &other)
Definition BoundedBinaryExpression.hpp:32
A concept base class for both Unary and BoundedBinary expressions.
Definition BoundedExpression.hpp:23
An expression.
Definition Expression.hpp:62
virtual auto GetCategory() const -> uint32_t
Gets the category of this expression.
Definition Expression.cpp:212
virtual auto GetType() const -> ExpressionType
Gets the type of this expression.
Definition Expression.cpp:220
virtual auto Simplify() const -> std::unique_ptr< Expression >
Simplifies this expression.
Definition Expression.cpp:244
Checks whether type T is the same as any of the provided types in U.
Definition Concepts.hpp:51
T find_if(T... args)
T is_same_v
Definition Add.hpp:11
@ Associative
Definition Expression.hpp:49