21#include "codac2_vec.h"
32 inline EvalMode operator&(EvalMode a, EvalMode b)
33 {
return static_cast<EvalMode
>(
static_cast<int>(a) &
static_cast<int>(b)); }
35 inline EvalMode operator|(EvalMode a, EvalMode b)
36 {
return static_cast<EvalMode
>(
static_cast<int>(a) |
static_cast<int>(b)); }
// Expression wrapper of vector type, derived publicly from AnalyticExprWrapper<VectorType>.
// NOTE(review): the braces and the constructor body are not visible in this listing
// (embedded numbering jumps 38 -> 41 and stops at 44) — confirm against the original header.
38 struct ScalarExprList :
public AnalyticExprWrapper<VectorType>
// Variadic constructor: accepted only when ExprType<S>::Type is ScalarType for every S
// in the pack. The base class is initialised with vec(y...).
41 template<
typename... S>
42 requires (std::is_same_v<typename ExprType<S>::Type,ScalarType> && ...)
43 ScalarExprList(
const S&... y)
44 : AnalyticExprWrapper<VectorType>(vec(y...))
// Class AnalyticFunction: constrained so that T derives from AnalyticTypeBase,
// and publicly derived from FunctionBase<AnalyticExpr<T>>.
// NOTE(review): the declaration of the template parameter T is not visible in this
// listing (it presumably sits on the line just before) — confirm against the original header.
55 requires std::is_base_of_v<AnalyticTypeBase,T>
56 class AnalyticFunction :
public FunctionBase<AnalyticExpr<T>>
// Constructor from an argument list and a ScalarExprList.
// Only available when T is VectorType (see the requires-clause).
// The visible part of the body checks with assert_release that
// y->belongs_to_args_list(this->args()) holds; the failure message reports a variable
// that is not present in the input arguments.
// NOTE(review): embedded numbering jumps from 61 to 64, so part of this definition
// (presumably the base-class initialisation and the opening brace) is not shown here —
// confirm against the original header.
60 AnalyticFunction(
const FunctionArgsList&
args,
const ScalarExprList& y)
61 requires(std::is_same_v<T,VectorType>)
64 assert_release(y->belongs_to_args_list(this->args()) &&
65 "Invalid argument: variable not present in input arguments");
// Constructor from an argument list and an expression wrapper of the function's own type T.
// The visible part of the body checks with assert_release that
// y->belongs_to_args_list(this->args()) holds; the failure message reports a variable
// that is not present in the input arguments.
// NOTE(review): embedded numbering jumps from 69 to 72, so part of this definition
// (presumably the base-class initialisation and the opening brace) is not shown here —
// confirm against the original header.
69 AnalyticFunction(
const FunctionArgsList&
args,
const AnalyticExprWrapper<T>& y)
72 assert_release(y->belongs_to_args_list(this->args()) &&
73 "Invalid argument: variable not present in input arguments");
// Copy constructor: builds an AnalyticFunction<T> from another one, f.
// NOTE(review): only the signature is visible in this listing; the initialiser list
// and body are not shown — confirm against the original header.
77 AnalyticFunction(
const AnalyticFunction<T>& f)
// Call operator taking an arbitrary pack of inputs x and returning an
// AnalyticExprWrapper<T> (i.e. an expression, not an evaluated domain).
// NOTE(review): only the signature is visible in this listing; the body is not shown —
// confirm against the original header.
81 template<
typename... X>
82 AnalyticExprWrapper<T> operator()(
const X&... x)
const
87 template<
typename... Args>
88 auto real_eval(
const Args&... x)
const
90 return eval(x...).mid();
// Evaluates the function on the inputs x according to the evaluation mode m and
// returns a value of type T::Domain.
// The inputs are first validated by check_valid_inputs(x...); the result then depends
// on the case label matching m.
// NOTE(review): the embedded numbering skips lines (e.g. 96 -> 100, 128 -> 133 -> 137),
// so the switch header, the braces and at least one branch condition are not visible in
// this listing — confirm against the original header before editing.
93 template<
typename... Args>
94 typename T::Domain eval(
const EvalMode& m,
const Args&... x)
const
96 check_valid_inputs(x...);
// NATURAL: return the `a` member of the result of eval_<true>(x...).
100 case EvalMode::NATURAL:
102 return eval_<true>(x...).a;
// CENTERED: return x_.m + x_.da * (X - X.mid()), where x_ = eval_<false>(x...) and X is
// cart_prod(x...) converted to an IntervalVector. The product is read as element [0] for
// scalar outputs, as column 0 for vector outputs, and reshaped to the shape of x_.m for
// matrix outputs.
105 case EvalMode::CENTERED:
107 auto x_ = eval_<false>(x...);
108 auto flatten_x = IntervalVector(cart_prod(x...));
// Debug-only consistency check between the sizes of x_.da, x_.a and the flattened input.
109 assert(x_.da.rows() == x_.a.size() && x_.da.cols() == flatten_x.size());
111 if constexpr(std::is_same_v<T,ScalarType>)
112 return x_.m + (x_.da*(flatten_x-flatten_x.mid()))[0];
114 else if constexpr(std::is_same_v<T,VectorType>)
115 return x_.m + (x_.da*(flatten_x-flatten_x.mid())).col(0);
// Remaining alternative: T must be MatrixType (enforced at compile time).
119 static_assert(std::is_same_v<T,MatrixType>);
120 return x_.m + (x_.da*(flatten_x-flatten_x.mid()))
121 .reshaped(x_.m.rows(), x_.m.cols());
// DEFAULT: compute x_ = eval_<false>(x...). One branch falls back to
// eval(EvalMode::NATURAL, x...); otherwise the result is x_.a combined through operator&
// (presumably an intersection — confirm) with the same centered expression as above.
// NOTE(review): the condition selecting the NATURAL fallback is not visible here.
125 case EvalMode::DEFAULT:
128 auto x_ = eval_<false>(x...);
133 return eval(EvalMode::NATURAL, x...);
137 auto flatten_x = IntervalVector(cart_prod(x...));
139 if constexpr(std::is_same_v<T,ScalarType>)
140 return x_.a & (x_.m + (x_.da*(flatten_x-flatten_x.mid()))[0]);
142 else if constexpr(std::is_same_v<T,VectorType>)
144 assert(x_.da.rows() == x_.a.size() && x_.da.cols() == flatten_x.size());
145 return x_.a & (x_.m + (x_.da*(flatten_x-flatten_x.mid())).col(0));
// Remaining alternative: T must be MatrixType (enforced at compile time).
150 static_assert(std::is_same_v<T,MatrixType>);
151 assert(x_.da.rows() == x_.a.size() && x_.da.cols() == flatten_x.size());
152 return x_.a & (x_.m +(x_.da*(flatten_x-flatten_x.mid()))
153 .reshaped(x_.m.rows(),x_.m.cols()));
160 template<
typename... Args>
161 auto diff(
const Args&... x)
const
163 check_valid_inputs(x...);
164 return eval_<false>(x...).da;
167 template<
typename... Args>
168 typename T::Domain eval(
const Args&... x)
const
170 return eval(EvalMode::NATURAL | EvalMode::CENTERED, x...);
// Evaluates the function along sampled trajectories.
// Creates a SampledTraj<typename T::Scalar> y, iterates over the (ti, xi) pairs of the
// first trajectory of the pack, and for each ti stores in y the value
// real_eval(x(ti)...), i.e. every trajectory of the pack is sampled at the same ti.
// NOTE(review): the braces and the return statement are not visible in this listing
// (numbering stops at 178) — confirm against the original header.
173 template<
typename... Args>
174 auto traj_eval(
const SampledTraj<Args>&... x)
const
176 SampledTraj<typename T::Scalar> y;
177 for(
const auto& [ti,xi] : std::get<0>(std::tie(x...)))
178 y.set(this->real_eval(x(ti)...),ti);
// Evaluates the function along sliced tubes.
// Takes the tdomain of the first tube of the pack, builds a SlicedTube<typename T::Domain> y
// on that tdomain with a T::Domain constructed from output_size(), then walks the tdomain
// and sets the codomain of y(it) to eval(x(it)->codomain()...).
// NOTE(review): the end of the constructor call of y, the braces and the return statement
// are not visible in this listing (numbering gaps 188 -> 191, end at 192) — confirm against
// the original header.
182 template<
typename... Args>
183 auto tube_eval(
const SlicedTube<Args>&... x)
const
185 auto tdomain = std::get<0>(std::tie(x...)).tdomain();
187 SlicedTube<typename T::Domain> y(
188 tdomain, (
typename T::Domain)(this->output_size())
191 for(
auto it = tdomain->begin() ; it != tdomain->end() ; it++)
192 y(it)->codomain() = this->eval(x(it)->codomain()...);
// Returns the size of the function output as an Index.
// For non-scalar types the visible code returns the product of the two components of
// output_shape().
// NOTE(review): the statement executed when T is ScalarType is not visible in this listing
// (numbering gap 199 -> 203) — confirm against the original header.
197 Index output_size()
const
199 if constexpr(std::is_same_v<T,ScalarType>)
203 std::pair<Index,Index> oshape = output_shape();
204 return oshape.first * oshape.second;
// Returns the shape of the function output as a pair of Index values.
// For non-scalar types the shape is delegated to this->expr()->output_shape().
// NOTE(review): the statement executed when T is ScalarType is not visible in this listing
// (numbering gap 210 -> 212) — confirm against the original header.
208 std::pair<Index,Index> output_shape()
const
210 if constexpr(std::is_same_v<T,ScalarType>)
212 else return this->
expr()->output_shape();
// Stream output of an AnalyticFunction (friend function).
// Writes the names of the arguments of f separated by commas, then ") ↦ " followed by
// the string form of the expression, f.expr()->str().
// NOTE(review): what is written before the argument names and the return statement are
// not visible in this listing (numbering gap 215 -> 218, end at 220) — confirm against the
// original header.
215 friend std::ostream& operator<<(std::ostream& os, [[maybe_unused]]
const AnalyticFunction<T>& f)
218 for(
size_t i = 0 ; i < f.args().size() ; i++)
// A comma is emitted before every name except the first one.
219 os << (i!=0 ?
"," :
"") << f.args()[i]->name();
220 os <<
") ↦ " << f.expr()->str();
// Fills the map v from the inputs x: add_value_to_arg_map(v, x, i++) is called once per
// element of the pack (comma fold expression), so each input receives the next index.
// NOTE(review): the declaration and initial value of i are not visible in this listing
// (numbering gap 233 -> 236) — confirm against the original header.
232 template<
typename... Args>
233 void fill_from_args(ValuesMap& v,
const Args&... x)
const
236 (add_value_to_arg_map(v, x, i++), ...);
// Updates the inputs x from the map v: intersect_value_from_arg_map(v, x, i++) is called
// once per element of the pack (comma fold expression), so each input receives the next
// index. The inputs are taken by non-const reference and may be modified.
// NOTE(review): the declaration and initial value of i are not visible in this listing
// (numbering gap 240 -> 243) — confirm against the original header.
239 template<
typename... Args>
240 void intersect_from_args(
const ValuesMap& v, Args&... x)
const
243 (intersect_value_from_arg_map(v, x, i++), ...);
// Registers the value x of the i-th argument into the map v.
// Visible steps:
//  - checks that 0 <= i < number of arguments (debug assert) and that size_of(x) equals
//    the size of the i-th argument (assert_release);
//  - builds an IntervalMatrix d of zeros with size_of(x) rows and args().total_size() columns;
//  - accumulates into p the sizes of the arguments that precede the i-th one;
//  - stores in v, under the unique_id() of the i-th argument, a shared D_TYPE built from
//    (D_TYPE::Domain(x).mid(), x, d, true).
// NOTE(review): the declaration of the template parameter D, the declaration of p and the
// body of the loop over k are not visible in this listing (numbering gaps 256 -> 259 and
// 262 -> 265) — confirm against the original header.
249 void add_value_to_arg_map(ValuesMap& v,
const D& x, Index i)
const
251 assert(i >= 0 && i < (Index)this->
args().size());
252 assert_release(size_of(x) == this->
args()[i]->size() &&
"provided arguments do not match function inputs");
// Expression type associated with the domain type D.
254 using D_TYPE =
typename ExprType<D>::Type;
256 IntervalMatrix d = IntervalMatrix::zero(size_of(x), this->
args().total_size());
259 for(Index j = 0 ; j < i ; j++)
260 p += this->
args()[j]->size();
262 for(Index k = p ; k < p+size_of(x) ; k++)
265 v[this->
args()[i]->unique_id()] =
266 std::make_shared<D_TYPE>(
typename D_TYPE::Domain(x).mid(), x, d,
true);
// Updates x in place from the value stored in v for the i-th argument.
// A debug assert checks that the unique_id() of the i-th argument is a key of v; the stored
// entry is then cast to ExprType<D>::Type and its `a` member is applied to x with `&=`
// (presumably an intersection — confirm).
// NOTE(review): the result of dynamic_pointer_cast is dereferenced without a null check.
// NOTE(review): the declaration of the template parameter D is not visible in this listing —
// confirm against the original header.
270 void intersect_value_from_arg_map(
const ValuesMap& v, D& x, Index i)
const
272 assert(v.find(this->args()[i]->unique_id()) != v.end() &&
"argument cannot be found");
273 x &= std::dynamic_pointer_cast<typename ExprType<D>::Type>(v.at(this->
args()[i]->unique_id()))->a;
// Internal evaluation helper, parameterised at compile time by NATURAL_EVAL.
//  - With no input, returns this->expr()->fwd_eval(v, 0, NATURAL_EVAL).
//  - Otherwise, fills v from the inputs with fill_from_args(v, x...) and returns
//    this->expr()->fwd_eval(v, cart_prod(x...).size(), NATURAL_EVAL).
// NOTE(review): the declaration of v and the braces / else branch marker are not visible in
// this listing (numbering gaps 277 -> 281 and 282 -> 286) — confirm against the original header.
276 template<
bool NATURAL_EVAL,
typename... Args>
277 auto eval_(
const Args&... x)
const
281 if constexpr(
sizeof...(Args) == 0)
282 return this->
expr()->fwd_eval(v, 0, NATURAL_EVAL);
286 fill_from_args(v, x...);
287 return this->
expr()->fwd_eval(v, cart_prod(x...).size(), NATURAL_EVAL);
// Checks that the inputs x match the arguments of the function.
// Sums size_of(x) over the whole pack into n (fold expression) and requires, through
// assert_release, that this total equals this->_args.total_size().
// n is marked [[maybe_unused]], presumably because assert_release may expand to nothing
// in some build configurations — confirm.
// NOTE(review): the braces are not visible in this listing — confirm against the original header.
291 template<
typename... Args>
292 void check_valid_inputs(
const Args&... x)
const
294 [[maybe_unused]] Index n = 0;
295 ((n += size_of(x)), ...);
297 assert_release(this->
_args.total_size() == n &&
298 "Invalid arguments: wrong number of input arguments");
// For every argument v of this->_args, calls this->_y->replace_arg with the argument's
// unique_id() and the argument itself cast to a shared pointer to ExprBase.
// NOTE(review): the braces are not visible in this listing (numbering gaps 301 -> 303 -> 305) —
// confirm against the original header.
301 inline void update_var_names()
303 for(
const auto& v : this->
_args)
305 this->
_y->replace_arg(v->unique_id(), std::dynamic_pointer_cast<ExprBase>(v));
// Deduction guide: constructing from a FunctionArgsList and an initializer list of
// ScalarExpr deduces AnalyticFunction<VectorType>.
309 AnalyticFunction(
const FunctionArgsList&, std::initializer_list<ScalarExpr>) ->
310 AnalyticFunction<VectorType>;
// Tail of a second deduction guide, which deduces AnalyticFunction<ExprType<T>::Type>.
// NOTE(review): the head of this guide (template header and parameter list) is not visible
// in this listing (numbering gap 310 -> 314) — confirm against the original header.
314 AnalyticFunction<typename ExprType<T>::Type>;
A container class to manage a collection of function arguments.
Definition codac2_FunctionArgsList.h:25
A base class for functions (either analytic functions, or set functions).
Definition codac2_FunctionBase.h:41
const FunctionArgsList & args() const
Definition codac2_FunctionBase.h:93
const FunctionArgsList _args
Definition codac2_FunctionBase.h:218
FunctionBase(const std::vector< std::reference_wrapper< VarBase > > &args, const std::shared_ptr< AnalyticExpr< T > > &y)
Definition codac2_FunctionBase.h:52
const std::shared_ptr< AnalyticExpr< T > > _y
Definition codac2_FunctionBase.h:217
const std::shared_ptr< AnalyticExpr< T > > & expr() const
Definition codac2_FunctionBase.h:113