OASIS
Open Algebra Software
Loading...
Searching...
No Matches
BinaryExpression.hpp
Go to the documentation of this file.
1//
2// Created by Matthew McCall on 7/2/23.
3//
4
5#ifndef OASIS_BINARYEXPRESSION_HPP
6#define OASIS_BINARYEXPRESSION_HPP
7
8#include <algorithm>
9#include <cassert>
10#include <functional>
11#include <list>
12
13#include "Expression.hpp"
14#include "RecursiveCast.hpp"
15#include "Visit.hpp"
16
17namespace Oasis {
27template <typename MostSigOpT, typename LeastSigOpT, typename T>
29
/**
 * Satisfied by a binary expression template T whose generalized form
 * T<Expression, Expression> models IExpression and whose static category
 * (GetStaticCategory()) has both the Associative and the Commutative bits set.
 *
 * Used to restrict the variadic "list" constructor of BinaryExpression to
 * operations where regrouping and reordering operands is allowed.
 */
template <template <typename, typename> typename T>
concept IAssociativeAndCommutative = IExpression<T<Expression, Expression>> && ((T<Expression, Expression>::GetStaticCategory() & (Associative | Commutative)) == (Associative | Commutative));
32
39template <template <typename, typename> typename T>
42{
43 if (ops.size() <= 1) {
44 return nullptr;
45 }
46
47 using GeneralizedT = T<Expression, Expression>;
48
50 opsList.resize(ops.size());
51
52 std::transform(ops.begin(), ops.end(), opsList.begin(), [](const auto& op) { return op->Copy(); });
53
54 while (std::next(opsList.begin()) != opsList.end()) {
55 for (auto i = opsList.begin(); i != opsList.end() && std::next(i) != opsList.end();) {
56 auto node = std::make_unique<GeneralizedT>(**i, **std::next(i));
57 opsList.insert(i, std::move(node));
58 i = opsList.erase(i, std::next(i, 2));
59 }
60 }
61
62 auto* result = dynamic_cast<GeneralizedT*>(opsList.front().release());
63 return std::unique_ptr<GeneralizedT>(result);
64}
65
81template <template <IExpression, IExpression> class DerivedT, IExpression MostSigOpT = Expression, IExpression LeastSigOpT = MostSigOpT>
83
84 using DerivedSpecialized = DerivedT<MostSigOpT, LeastSigOpT>;
85 using DerivedGeneralized = DerivedT<Expression, Expression>;
86
87public:
88 BinaryExpression() = default;
90 {
91 if (other.HasMostSigOp()) {
93 }
94
95 if (other.HasLeastSigOp()) {
97 }
98 }
99
100 BinaryExpression(const MostSigOpT& mostSigOp, const LeastSigOpT& leastSigOp)
101 {
104 }
105
106 template <IExpression Op1T, IExpression Op2T, IExpression... OpsT>
107 BinaryExpression(const Op1T& op1, const Op2T& op2, const OpsT&... ops)
108 {
109 static_assert(IAssociativeAndCommutative<DerivedT>, "List initializer only supported for associative and commutative expressions");
110 static_assert(std::is_same_v<DerivedGeneralized, DerivedSpecialized>, "List initializer only supported for generalized expressions");
111
113
114 for (auto opWrapper : std::vector<std::reference_wrapper<const Expression>> { static_cast<const Expression&>(op1), static_cast<const Expression&>(op2), (static_cast<const Expression&>(ops))... }) {
115 const Expression& operand = opWrapper.get();
116 opsVec.emplace_back(operand.Copy());
117 }
118
119 // build expression from vector
120 auto generalized = BuildFromVector<DerivedT>(opsVec);
121
122 SetLeastSigOp(generalized->GetLeastSigOp());
123 SetMostSigOp(generalized->GetMostSigOp());
124 }
125
126 [[nodiscard]] auto Copy() const -> std::unique_ptr<Expression> final
127 {
128 return std::make_unique<DerivedSpecialized>(*static_cast<const DerivedSpecialized*>(this));
129 }
130
131 [[nodiscard]] auto Differentiate(const Expression& differentiationVariable) const -> std::unique_ptr<Expression> override
132 {
133 return Generalize()->Differentiate(differentiationVariable);
134 }
135 [[nodiscard]] auto Equals(const Expression& other) const -> bool final
136 {
137 if (this->GetType() != other.GetType()) {
138 return false;
139 }
140
141 const auto otherGeneralized = other.Generalize();
142 const auto& otherBinaryGeneralized = static_cast<const DerivedGeneralized&>(*otherGeneralized);
143
144 bool mostSigOpMismatch = false, leastSigOpMismatch = false;
145
146 if (this->HasMostSigOp() == otherBinaryGeneralized.HasMostSigOp()) {
147 if (mostSigOp && otherBinaryGeneralized.HasMostSigOp()) {
148 mostSigOpMismatch = !mostSigOp->Equals(otherBinaryGeneralized.GetMostSigOp());
149 }
150 } else {
151 mostSigOpMismatch = true;
152 }
153
154 if (this->HasLeastSigOp() == otherBinaryGeneralized.HasLeastSigOp()) {
155 if (this->HasLeastSigOp() && otherBinaryGeneralized.HasLeastSigOp()) {
156 leastSigOpMismatch = !leastSigOp->Equals(otherBinaryGeneralized.GetLeastSigOp());
157 }
158 } else {
159 mostSigOpMismatch = true;
160 }
161
162 if (!mostSigOpMismatch && !leastSigOpMismatch) {
163 return true;
164 }
165
166 if (!(this->GetCategory() & Associative)) {
167 return false;
168 }
169
170 auto thisFlattened = std::vector<std::unique_ptr<Expression>> {};
171 auto otherFlattened = std::vector<std::unique_ptr<Expression>> {};
172
173 this->Flatten(thisFlattened);
174 otherBinaryGeneralized.Flatten(otherFlattened);
175
176 for (const auto& thisOperand : thisFlattened) {
177 if (std::find_if(otherFlattened.begin(), otherFlattened.end(), [&thisOperand](const auto& otherOperand) {
178 return thisOperand->Equals(*otherOperand);
179 })
180 == otherFlattened.end()) {
181 return false;
182 }
183 }
184
185 return true;
186 }
187
188 [[nodiscard]] auto Generalize() const -> std::unique_ptr<Expression> final
189 {
190 DerivedGeneralized generalized;
191
192 if (this->mostSigOp) {
193 generalized.SetMostSigOp(*this->mostSigOp->Copy());
194 }
195
196 if (this->leastSigOp) {
197 generalized.SetLeastSigOp(*this->leastSigOp->Copy());
198 }
199
200 return std::make_unique<DerivedGeneralized>(generalized);
201 }
202
203 [[nodiscard]] auto Simplify() const -> std::unique_ptr<Expression> override
204 {
205 return Generalize()->Simplify();
206 }
207
208 [[nodiscard]] auto Integrate(const Expression& integrationVariable) const -> std::unique_ptr<Expression> override
209 {
210 return Generalize()->Integrate(integrationVariable);
211 }
212
213 [[nodiscard]] auto StructurallyEquivalent(const Expression& other) const -> bool final
214 {
215 if (this->GetType() != other.GetType()) {
216 return false;
217 }
218
219 const std::unique_ptr<Expression> otherGeneralized = other.Generalize();
220 const auto& otherBinaryGeneralized = static_cast<const DerivedGeneralized&>(*otherGeneralized);
221
222 if (this->HasMostSigOp() == otherBinaryGeneralized.HasMostSigOp()) {
223 if (this->HasMostSigOp() && otherBinaryGeneralized.HasMostSigOp()) {
224 if (!mostSigOp->StructurallyEquivalent(otherBinaryGeneralized.GetMostSigOp())) {
225 return false;
226 }
227 }
228 }
229
230 if (this->HasLeastSigOp() == otherBinaryGeneralized.HasLeastSigOp()) {
231 if (this->HasLeastSigOp() && otherBinaryGeneralized.HasLeastSigOp()) {
232 if (!leastSigOp->StructurallyEquivalent(otherBinaryGeneralized.GetLeastSigOp())) {
233 return false;
234 }
235 }
236 }
237
238 return true;
239 }
240
253 {
254 if (mostSigOp) {
255 if (this->mostSigOp->template Is<DerivedGeneralized>()) {
256 auto generalizedMostSigOp = this->mostSigOp->Generalize();
257 const auto& mostSigOp = static_cast<const DerivedGeneralized&>(*generalizedMostSigOp);
258 mostSigOp.Flatten(out);
259 } else {
260 out.push_back(this->mostSigOp->Copy());
261 }
262 }
263
264 if (leastSigOp) {
265 if (this->leastSigOp->template Is<DerivedGeneralized>()) {
266 auto generalizedLeastSigOp = this->leastSigOp->Generalize();
267 const auto& leastSigOp = static_cast<const DerivedGeneralized&>(*generalizedLeastSigOp);
268 leastSigOp.Flatten(out);
269 } else {
270 out.push_back(this->leastSigOp->Copy());
271 }
272 }
273 }
274
279 auto GetMostSigOp() const -> const MostSigOpT&
280 {
281 assert(mostSigOp != nullptr);
282 return *mostSigOp;
283 }
284
289 auto GetLeastSigOp() const -> const LeastSigOpT&
290 {
291 assert(leastSigOp != nullptr);
292 return *leastSigOp;
293 }
294
299 [[nodiscard]] auto HasMostSigOp() const -> bool
300 {
301 return mostSigOp != nullptr;
302 }
303
308 [[nodiscard]] auto HasLeastSigOp() const -> bool
309 {
310 return leastSigOp != nullptr;
311 }
312
317 template <typename T>
319 auto SetMostSigOp(const T& op) -> bool
320 {
322 this->mostSigOp = op.Copy();
323 return true;
324 }
325
328 return true;
329 }
330
331 if (auto castedOp = Oasis::RecursiveCast<MostSigOpT>(op); castedOp) {
332 mostSigOp = std::move(castedOp);
333 return true;
334 }
335
336 return false;
337 }
338
343 template <typename T>
345 auto SetLeastSigOp(const T& op) -> bool
346 {
348 this->leastSigOp = op.Copy();
349 return true;
350 }
351
354 return true;
355 }
356
357 if (auto castedOp = Oasis::RecursiveCast<LeastSigOpT>(op); castedOp) {
358 leastSigOp = std::move(castedOp);
359 return true;
360 }
361
362 return false;
363 }
364
365 auto Substitute(const Expression& var, const Expression& val) -> std::unique_ptr<Expression> override
366 {
368 std::unique_ptr<Expression> right = ((GetLeastSigOp().Copy())->Substitute(var, val));
369 DerivedT<Expression, Expression> comb = DerivedT<Expression, Expression> { *left, *right };
370 auto ret = comb.Simplify();
371 return ret;
372 }
377 auto SwapOperands() const -> DerivedT<LeastSigOpT, MostSigOpT>
378 {
379 return DerivedT { *this->leastSigOp, *this->mostSigOp };
380 }
381
382 auto operator=(const BinaryExpression& other) -> BinaryExpression& = default;
383
384 auto AcceptInternal(Visitor& visitor) const -> any override
385 {
386 const auto generalized = Generalize();
387 const auto& derivedGeneralized = dynamic_cast<const DerivedGeneralized&>(*generalized);
388 return visitor.Visit(derivedGeneralized);
389 }
390
393};
394
395} // Oasis
396
397#endif // OASIS_BINARYEXPRESSION_HPP
T begin(T... args)
A binary expression.
Definition BinaryExpression.hpp:82
auto Simplify() const -> std::unique_ptr< Expression > override
Simplifies this expression.
Definition BinaryExpression.hpp:203
auto Copy() const -> std::unique_ptr< Expression > final
Copies this expression.
Definition BinaryExpression.hpp:126
BinaryExpression(const BinaryExpression &other)
Definition BinaryExpression.hpp:89
std::unique_ptr< LeastSigOpT > leastSigOp
Definition BinaryExpression.hpp:392
auto Differentiate(const Expression &differentiationVariable) const -> std::unique_ptr< Expression > override
Tries to differentiate this function.
Definition BinaryExpression.hpp:131
auto AcceptInternal(Visitor &visitor) const -> any override
Accepts a visitor (for example, a serializer) by dispatching it on the generalized form of this expression.
Definition BinaryExpression.hpp:384
auto Substitute(const Expression &var, const Expression &val) -> std::unique_ptr< Expression > override
Definition BinaryExpression.hpp:365
auto GetLeastSigOp() const -> const LeastSigOpT &
Gets the least significant operand of this expression.
Definition BinaryExpression.hpp:289
auto operator=(const BinaryExpression &other) -> BinaryExpression &=default
auto SetLeastSigOp(const T &op) -> bool
Sets the least significant operand of this expression.
Definition BinaryExpression.hpp:345
auto Integrate(const Expression &integrationVariable) const -> std::unique_ptr< Expression > override
Attempts to integrate this expression using integration rules.
Definition BinaryExpression.hpp:208
auto Generalize() const -> std::unique_ptr< Expression > final
Converts this expression to a more general expression.
Definition BinaryExpression.hpp:188
auto GetMostSigOp() const -> const MostSigOpT &
Gets the most significant operand of this expression.
Definition BinaryExpression.hpp:279
std::unique_ptr< MostSigOpT > mostSigOp
Definition BinaryExpression.hpp:391
auto SwapOperands() const -> DerivedT< LeastSigOpT, MostSigOpT >
Swaps the operands of this expression.
Definition BinaryExpression.hpp:377
auto SetMostSigOp(const T &op) -> bool
Sets the most significant operand of this expression.
Definition BinaryExpression.hpp:319
auto StructurallyEquivalent(const Expression &other) const -> bool final
Checks whether this expression is structurally equivalent to another expression.
Definition BinaryExpression.hpp:213
auto Flatten(std::vector< std::unique_ptr< Expression > > &out) const -> void
Flattens this expression.
Definition BinaryExpression.hpp:252
auto HasLeastSigOp() const -> bool
Gets whether this expression has a least significant operand.
Definition BinaryExpression.hpp:308
BinaryExpression(const Op1T &op1, const Op2T &op2, const OpsT &... ops)
Definition BinaryExpression.hpp:107
auto Equals(const Expression &other) const -> bool final
Compares this expression to another expression for equality.
Definition BinaryExpression.hpp:135
BinaryExpression(const MostSigOpT &mostSigOp, const LeastSigOpT &leastSigOp)
Definition BinaryExpression.hpp:100
auto HasMostSigOp() const -> bool
Gets whether this expression has a most significant operand.
Definition BinaryExpression.hpp:299
An expression.
Definition Expression.hpp:62
virtual auto Copy() const -> std::unique_ptr< Expression >=0
Copies this expression.
virtual auto GetCategory() const -> uint32_t
Gets the category of this expression.
Definition Expression.cpp:212
virtual auto GetType() const -> ExpressionType
Gets the type of this expression.
Definition Expression.cpp:220
virtual auto Simplify() const -> std::unique_ptr< Expression >
Simplifies this expression.
Definition Expression.cpp:244
Definition Visit.hpp:50
Definition BinaryExpression.hpp:31
An expression concept.
Definition Concepts.hpp:28
A concept for an operand of a binary expression.
Definition BinaryExpression.hpp:28
Checks if type T is same as any of the provided types in U.
Definition Concepts.hpp:51
T emplace_back(T... args)
T end(T... args)
T erase(T... args)
T find_if(T... args)
T front(T... args)
T insert(T... args)
T is_same_v
Definition Add.hpp:11
auto BuildFromVector(const std::vector< std::unique_ptr< Expression > > &ops) -> std::unique_ptr< T< Expression, Expression > >
Builds a reasonably balanced binary expression from a vector of operands.
Definition BinaryExpression.hpp:41
boost::anys::unique_any any
Definition Expression.hpp:15
@ Commutative
Definition Expression.hpp:50
@ Associative
Definition Expression.hpp:49
T next(T... args)
T resize(T... args)
T transform(T... args)