22 using ValuesMap = std::map<ExprID,std::shared_ptr<AnalyticTypeBase>>;
29 virtual T fwd_eval(ValuesMap& v, Index total_input_size,
bool natural_eval)
const = 0;
30 virtual void bwd_eval(ValuesMap& v)
const = 0;
31 virtual std::pair<Index,Index> output_shape()
const = 0;
33 const T& init_value(ValuesMap& v,
const T& x)
const
39 p = std::make_shared<T>(x);
40 return *
static_cast<T*
>(p.get());
45 auto pt = std::dynamic_pointer_cast<T>(p);
46 assert(pt &&
"Type mismatch in ValuesMap for this ExprID");
52 T& value(ValuesMap& v)
const
55 assert(it != v.end() &&
"argument cannot be found");
56 auto p = std::dynamic_pointer_cast<T>(it->second);
57 assert(p &&
"Type mismatch in ValuesMap for this ExprID");
61 virtual bool belongs_to_args_list(
const FunctionArgsList& args)
const = 0;
62 virtual std::string str(
bool in_parentheses =
false)
const = 0;
63 virtual bool is_str_leaf()
const = 0;
66 template<
typename C,
typename Y,
typename... X>
67 class AnalyticOperationExpr :
public AnalyticExpr<Y>,
public OperationExprBase<AnalyticExpr<X>...>
71 AnalyticOperationExpr(std::shared_ptr<AnalyticExpr<X>>... x)
75 AnalyticOperationExpr(
const AnalyticOperationExpr<C,Y,X...>& e)
79 std::shared_ptr<ExprBase>
copy()
const
81 return std::make_shared<AnalyticOperationExpr<C,Y,X...>>(*this);
84 void replace_arg(
const ExprID& old_arg_id,
const std::shared_ptr<ExprBase>& new_expr)
89 Y fwd_eval(ValuesMap& v, Index total_input_size,
bool natural_eval)
const
92 [
this,&v,total_input_size,natural_eval](
auto &&... x)
95 return AnalyticExpr<Y>::init_value(v,
96 C::fwd_natural(x->fwd_eval(v, total_input_size, natural_eval)...));
99 return AnalyticExpr<Y>::init_value(v,
100 C::fwd_centered(x->fwd_eval(v, total_input_size, natural_eval)...));
105 void bwd_eval(ValuesMap& v)
const
107 auto y = AnalyticExpr<Y>::value(v);
109 std::apply([&v,y](
auto &&... x)
111 C::bwd(y.a, x->value(v).a...);
114 std::apply([&v](
auto &&... x)
116 (x->bwd_eval(v), ...);
120 virtual std::string str(
bool in_parentheses =
false)
const
122 std::string s = std::apply([](
auto &&... x) {
125 return in_parentheses ?
"(" + s +
")" : s;
128 virtual bool is_str_leaf()
const
133 std::pair<Index,Index> output_shape()
const
135 std::pair<Index,Index> s;
136 std::apply([&s](
auto &&... x)
138 s = C::output_shape(x...);
143 virtual bool belongs_to_args_list(
const FunctionArgsList& args)
const
147 std::apply([&b,&args](
auto &&... x)
149 ((b &= x->belongs_to_args_list(args)), ...);
Abstract base class for representing an expression.
Definition codac2_ExprBase.h:86
const ExprID & unique_id() const
Returns the unique identifier of the expression.
virtual std::shared_ptr< ExprBase > copy() const =0
Creates a copy of the current expression.
A base class for expressions representing operations with multiple operands.
Definition codac2_ExprBase.h:165
OperationExprBase(std::shared_ptr< X >... x)
Definition codac2_ExprBase.h:177
std::tuple< std::shared_ptr< X >... > _x
Definition codac2_ExprBase.h:259
Definition codac2_OctaSym.h:21