codac 2.0.0
Loading...
Searching...
No Matches
codac2_OctaSym_operator.h
Go to the documentation of this file.
1
9
10#pragma once
11
12#include "codac2_OctaSym.h"
13
14namespace codac2
15{
16 struct OctaSymOp
17 {
18 template<typename X1>
19 static inline std::string str(const X1& x1)
20 {
21 return "sym(" + x1->str() + ")";
22 }
23
24 template<typename X1>
25 static inline std::pair<Index,Index> output_shape([[maybe_unused]] const X1& s1)
26 {
27 return s1->output_shape();
28 }
29
30 static inline IntervalVector fwd(const OctaSym& s, const IntervalVector& x1)
31 {
32 assert((Index)s.size() == x1.size());
33 return s(x1);
34 }
35
36 static inline VectorType fwd_natural(const OctaSym& s, const VectorType& x1)
37 {
38 assert((Index)s.size() == x1.m.size());
39 return {
40 fwd(s, x1.a),
41 x1.def_domain
42 };
43 }
44
45 static inline VectorType fwd_centered(const OctaSym& s, const VectorType& x1)
46 {
47 assert((Index)s.size() == x1.m.size());
48
49 auto da = x1.da;
50 for(size_t i = 0 ; i < s.size() ; i++)
51 da.row(i) = sign(s[i])*x1.da.row(std::abs(s[i])-1);
52
53 return {
54 fwd(s, x1.m),
55 fwd(s, x1.a),
56 da,
57 x1.def_domain
58 };
59 }
60
61 static inline void bwd(const OctaSym& s, const IntervalVector& y, IntervalVector& x1)
62 {
63 assert((Index)s.size() == y.size() && (Index)s.size() == x1.size());
64 x1 &= s.invert()(y);
65 }
66 };
67
68
69 template<>
70 class AnalyticOperationExpr<OctaSymOp,VectorType,VectorType>
71 : public AnalyticExpr<VectorType>, public OperationExprBase<AnalyticExpr<VectorType>>
72 {
73 public:
74
75 AnalyticOperationExpr(const OctaSym& s, const VectorExpr& x1)
76 : OperationExprBase<AnalyticExpr<VectorType>>(x1), _s(s)
77 { }
78
79 std::shared_ptr<ExprBase> copy() const
80 {
81 return std::make_shared<AnalyticOperationExpr<OctaSymOp,VectorType,VectorType>>(*this);
82 }
83
84 void replace_arg(const ExprID& old_arg_id, const std::shared_ptr<ExprBase>& new_expr)
85 {
86 return OperationExprBase<AnalyticExpr<VectorType>>::replace_arg(old_arg_id, new_expr);
87 }
88
89 VectorType fwd_eval(ValuesMap& v, Index total_input_size, bool natural_eval) const
90 {
91 if(natural_eval)
92 return AnalyticExpr<VectorType>::init_value(
93 v, OctaSymOp::fwd_natural(_s, std::get<0>(this->_x)->fwd_eval(v, total_input_size, natural_eval)));
94 else
95 return AnalyticExpr<VectorType>::init_value(
96 v, OctaSymOp::fwd_centered(_s, std::get<0>(this->_x)->fwd_eval(v, total_input_size, natural_eval)));
97 }
98
99 void bwd_eval(ValuesMap& v) const
100 {
101 OctaSymOp::bwd(_s, AnalyticExpr<VectorType>::value(v).a, std::get<0>(this->_x)->value(v).a);
102 std::get<0>(this->_x)->bwd_eval(v);
103 }
104
105 std::pair<Index,Index> output_shape() const {
106 return { _s.size(), 1 };
107 }
108
109 virtual bool belongs_to_args_list(const FunctionArgsList& args) const
110 {
111 return std::get<0>(this->_x)->belongs_to_args_list(args);
112 }
113
114 std::string str(bool in_parentheses = false) const
115 {
116 std::string s = "S"; // user cannot (yet) specify a name for the symmetry
117 return in_parentheses ? "(" + s + ")" : s;
118 }
119
120 virtual bool is_str_leaf() const
121 {
122 return true;
123 }
124
125 protected:
126
127 const OctaSym _s;
128 };
129}
A base class for expressions representing operations with multiple operands.
Definition codac2_ExprBase.h:164
OperationExprBase(std::shared_ptr< X >... x)
Definition codac2_ExprBase.h:176
std::tuple< std::shared_ptr< X >... > _x
Definition codac2_ExprBase.h:258
Definition codac2_OctaSym.h:21
Eigen::Matrix< Interval,-1, 1 > IntervalVector
Alias for a dynamic-size column vector of intervals.
Definition codac2_IntervalVector.h:25
Interval sign(const Interval &x)
Returns the interval sign of x, i.e. \(\mathrm{sign}([x])\).
Definition codac2_Interval_operations_impl.h:279