codac 2.0.0
Loading...
Searching...
No Matches
codac2_AnalyticFunction.h
Go to the documentation of this file.
1
9
10#pragma once
11
12#include <map>
13#include "codac2_AnalyticExpr.h"
14#include "codac2_Domain.h"
16#include "codac2_FunctionBase.h"
19#include "codac2_operators.h"
20#include "codac2_cart_prod.h"
21#include "codac2_vec.h"
22
23namespace codac2
24{
/**
 * \brief Evaluation strategy used by AnalyticFunction::eval().
 *
 * The values are bit flags that can be combined with operator| and
 * operator&; DEFAULT is exactly NATURAL|CENTERED.
 */
enum class EvalMode : int
{
  NATURAL = 0x01,  ///< natural evaluation of the expression
  CENTERED = 0x02, ///< centered form, built from the derivatives of the expression
  DEFAULT = 0x03   ///< NATURAL|CENTERED: intersection of both evaluations
};

/// Bitwise AND of two evaluation modes (constexpr, usable in constant expressions).
constexpr EvalMode operator&(EvalMode a, EvalMode b)
{ return static_cast<EvalMode>(static_cast<int>(a) & static_cast<int>(b)); }

/// Bitwise OR of two evaluation modes, e.g. NATURAL|CENTERED == DEFAULT.
constexpr EvalMode operator|(EvalMode a, EvalMode b)
{ return static_cast<EvalMode>(static_cast<int>(a) | static_cast<int>(b)); }
37
38 struct ScalarExprList : public AnalyticExprWrapper<VectorType>
39 {
40 // Mainly used to take advantage of initializer lists in AnalyticFunction constructors.
41 template<typename... S>
42 requires (std::is_same_v<typename ExprType<S>::Type,ScalarType> && ...)
43 ScalarExprList(const S&... y)
44 : AnalyticExprWrapper<VectorType>(vec(y...))
45 { }
46 };
47
// Forward declarations of the trajectory and tube types taken as inputs by
// AnalyticFunction::traj_eval() and AnalyticFunction::tube_eval().
template<typename T> class SampledTraj;
template<typename T> class SlicedTube;
53
54 template<typename T>
55 requires std::is_base_of_v<AnalyticTypeBase,T>
56 class AnalyticFunction : public FunctionBase<AnalyticExpr<T>>
57 {
58 public:
59
60 AnalyticFunction(const FunctionArgsList& args, const ScalarExprList& y)
61 requires(std::is_same_v<T,VectorType>)
62 : FunctionBase<AnalyticExpr<T>>(args, y)
63 {
64 assert_release(y->belongs_to_args_list(this->args()) &&
65 "Invalid argument: variable not present in input arguments");
66 update_var_names();
67 }
68
69 AnalyticFunction(const FunctionArgsList& args, const AnalyticExprWrapper<T>& y)
70 : FunctionBase<AnalyticExpr<T>>(args, y)
71 {
72 assert_release(y->belongs_to_args_list(this->args()) &&
73 "Invalid argument: variable not present in input arguments");
74 update_var_names();
75 }
76
77 AnalyticFunction(const AnalyticFunction<T>& f)
78 : FunctionBase<AnalyticExpr<T>>(f)
79 { }
80
81 template<typename... X>
82 AnalyticExprWrapper<T> operator()(const X&... x) const
83 {
84 return { this->FunctionBase<AnalyticExpr<T>>::operator()(x...) };
85 }
86
87 template<typename... Args>
88 auto real_eval(const Args&... x) const
89 {
90 return eval(x...).mid();
91 }
92
93 template<typename... Args>
94 typename T::Domain eval(const EvalMode& m, const Args&... x) const
95 {
96 check_valid_inputs(x...);
97
98 switch(m)
99 {
100 case EvalMode::NATURAL:
101 {
102 return eval_<true>(x...).a;
103 }
104
105 case EvalMode::CENTERED:
106 {
107 auto x_ = eval_<false>(x...);
108 auto flatten_x = IntervalVector(cart_prod(x...));
109 assert(x_.da.rows() == x_.a.size() && x_.da.cols() == flatten_x.size());
110
111 if constexpr(std::is_same_v<T,ScalarType>)
112 return x_.m + (x_.da*(flatten_x-flatten_x.mid()))[0];
113
114 else if constexpr(std::is_same_v<T,VectorType>)
115 return x_.m + (x_.da*(flatten_x-flatten_x.mid())).col(0);
116
117 else
118 {
119 static_assert(std::is_same_v<T,MatrixType>);
120 return x_.m + (x_.da*(flatten_x-flatten_x.mid()))
121 .reshaped(x_.m.rows(), x_.m.cols());
122 }
123 }
124
125 case EvalMode::DEFAULT:
126 default:
127 {
128 auto x_ = eval_<false>(x...);
129
130 // If the centered form is not available for this expression...
131 if(x_.da.size() == 0 // .. because some parts have not yet been implemented,
132 || !x_.def_domain) // .. or due to restrictions in the derivative definition domain
133 return eval(EvalMode::NATURAL, x...);
134
135 else
136 {
137 auto flatten_x = IntervalVector(cart_prod(x...));
138
139 if constexpr(std::is_same_v<T,ScalarType>)
140 return x_.a & (x_.m + (x_.da*(flatten_x-flatten_x.mid()))[0]);
141
142 else if constexpr(std::is_same_v<T,VectorType>)
143 {
144 assert(x_.da.rows() == x_.a.size() && x_.da.cols() == flatten_x.size());
145 return x_.a & (x_.m + (x_.da*(flatten_x-flatten_x.mid())).col(0));
146 }
147
148 else
149 {
150 static_assert(std::is_same_v<T,MatrixType>);
151 assert(x_.da.rows() == x_.a.size() && x_.da.cols() == flatten_x.size());
152 return x_.a & (x_.m +(x_.da*(flatten_x-flatten_x.mid()))
153 .reshaped(x_.m.rows(),x_.m.cols()));
154 }
155 }
156 }
157 }
158 }
159
160 template<typename... Args>
161 auto diff(const Args&... x) const
162 {
163 check_valid_inputs(x...);
164 return eval_<false>(x...).da;
165 }
166
167 template<typename... Args>
168 typename T::Domain eval(const Args&... x) const
169 {
170 return eval(EvalMode::NATURAL | EvalMode::CENTERED, x...);
171 }
172
173 template<typename... Args>
174 auto traj_eval(const SampledTraj<Args>&... x) const
175 {
176 SampledTraj<typename T::Scalar> y;
177 for(const auto& [ti,xi] : std::get<0>(std::tie(x...)))
178 y.set(this->real_eval(x(ti)...),ti);
179 return y;
180 }
181
182 template<typename... Args>
183 auto tube_eval(const SlicedTube<Args>&... x) const
184 {
185 auto tdomain = std::get<0>(std::tie(x...)).tdomain();
186
187 SlicedTube<typename T::Domain> y(
188 tdomain, (typename T::Domain)(this->output_size())
189 );
190
191 for(auto it = tdomain->begin() ; it != tdomain->end() ; it++)
192 y(it)->codomain() = this->eval(x(it)->codomain()...);
193
194 return y;
195 }
196
197 Index output_size() const
198 {
199 if constexpr(std::is_same_v<T,ScalarType>)
200 return 1;
201
202 else {
203 std::pair<Index,Index> oshape = output_shape();
204 return oshape.first * oshape.second;
205 }
206 }
207
208 std::pair<Index,Index> output_shape() const
209 {
210 if constexpr(std::is_same_v<T,ScalarType>)
211 return {1,1};
212 else return this->expr()->output_shape();
213 }
214
215 friend std::ostream& operator<<(std::ostream& os, [[maybe_unused]] const AnalyticFunction<T>& f)
216 {
217 os << "(";
218 for(size_t i = 0 ; i < f.args().size() ; i++)
219 os << (i!=0 ? "," : "") << f.args()[i]->name();
220 os << ") ↦ " << f.expr()->str();
221 return os;
222 }
223
224 // not working with Clang: template<typename Y, typename... X>
225 // not working with Clang: requires (sizeof...(X) > 0)
226 // not working with Clang: friend class CtcInverse;
227
228 // So, the following methods are temporarily public
229
230 // protected:
231
232 template<typename... Args>
233 void fill_from_args(ValuesMap& v, const Args&... x) const
234 {
235 Index i = 0;
236 (add_value_to_arg_map(v, x, i++), ...);
237 }
238
239 template<typename... Args>
240 void intersect_from_args(const ValuesMap& v, Args&... x) const
241 {
242 Index i = 0;
243 (intersect_value_from_arg_map(v, x, i++), ...);
244 }
245
246 protected:
247
248 template<typename D>
249 void add_value_to_arg_map(ValuesMap& v, const D& x, Index i) const
250 {
251 assert(i >= 0 && i < (Index)this->args().size());
252 assert_release(size_of(x) == this->args()[i]->size() && "provided arguments do not match function inputs");
253
254 using D_TYPE = typename ExprType<D>::Type;
255
256 IntervalMatrix d = IntervalMatrix::zero(size_of(x), this->args().total_size());
257
258 Index p = 0;
259 for(Index j = 0 ; j < i ; j++)
260 p += this->args()[j]->size();
261
262 for(Index k = p ; k < p+size_of(x) ; k++)
263 d(k-p,k) = 1.;
264
265 v[this->args()[i]->unique_id()] =
266 std::make_shared<D_TYPE>(typename D_TYPE::Domain(x).mid(), x, d, true);
267 }
268
269 template<typename D>
270 void intersect_value_from_arg_map(const ValuesMap& v, D& x, Index i) const
271 {
272 assert(v.find(this->args()[i]->unique_id()) != v.end() && "argument cannot be found");
273 x &= std::dynamic_pointer_cast<typename ExprType<D>::Type>(v.at(this->args()[i]->unique_id()))->a;
274 }
275
276 template<bool NATURAL_EVAL,typename... Args>
277 auto eval_(const Args&... x) const
278 {
279 ValuesMap v;
280
281 if constexpr(sizeof...(Args) == 0)
282 return this->expr()->fwd_eval(v, 0, NATURAL_EVAL);
283
284 else
285 {
286 fill_from_args(v, x...);
287 return this->expr()->fwd_eval(v, cart_prod(x...).size(), NATURAL_EVAL); // todo: improve size computation
288 }
289 }
290
291 template<typename... Args>
292 void check_valid_inputs(const Args&... x) const
293 {
294 [[maybe_unused]] Index n = 0;
295 ((n += size_of(x)), ...);
296
297 assert_release(this->_args.total_size() == n &&
298 "Invalid arguments: wrong number of input arguments");
299 }
300
301 inline void update_var_names()
302 {
303 for(const auto& v : this->_args) // variable names are automatically computed in FunctionArgsList,
304 // so we propagate them to the expression
305 this->_y->replace_arg(v->unique_id(), std::dynamic_pointer_cast<ExprBase>(v));
306 }
307 };
308
309 AnalyticFunction(const FunctionArgsList&, std::initializer_list<ScalarExpr>) ->
310 AnalyticFunction<VectorType>;
311
312 template<typename T>
313 AnalyticFunction(const FunctionArgsList&, const T&) ->
314 AnalyticFunction<typename ExprType<T>::Type>;
315
316}
A container class to manage a collection of function arguments.
Definition codac2_FunctionArgsList.h:25
A base class for functions (either analytic functions, or set functions).
Definition codac2_FunctionBase.h:41
const FunctionArgsList & args() const
Definition codac2_FunctionBase.h:93
const FunctionArgsList _args
Definition codac2_FunctionBase.h:218
FunctionBase(const std::vector< std::reference_wrapper< VarBase > > &args, const std::shared_ptr< AnalyticExpr< T > > &y)
Definition codac2_FunctionBase.h:52
const std::shared_ptr< AnalyticExpr< T > > _y
Definition codac2_FunctionBase.h:217
const std::shared_ptr< AnalyticExpr< T > > & expr() const
Definition codac2_FunctionBase.h:113