22 using ValuesMap = std::map<ExprID,std::shared_ptr<AnalyticTypeBase>>;
29 AnalyticExpr<T>& operator=(
const AnalyticExpr<T>& x) =
delete;
31 virtual T fwd_eval(ValuesMap& v, Index total_input_size,
bool natural_eval)
const = 0;
32 virtual void bwd_eval(ValuesMap& v)
const = 0;
33 virtual std::pair<Index,Index> output_shape()
const = 0;
35 T init_value(ValuesMap& v,
const T& x)
const
45 *std::dynamic_pointer_cast<T>(it->second) = x;
46 return *std::dynamic_pointer_cast<T>(it->second);
49 T& value(ValuesMap& v)
const
51 assert(v.find(
unique_id()) != v.end() &&
"argument cannot be found");
52 return *std::dynamic_pointer_cast<T>(v[
unique_id()]);
55 virtual bool belongs_to_args_list(
const FunctionArgsList& args)
const = 0;
56 virtual std::string str(
bool in_parentheses =
false)
const = 0;
57 virtual bool is_str_leaf()
const = 0;
60 template<
typename C,
typename Y,
typename... X>
61 class AnalyticOperationExpr :
public AnalyticExpr<Y>,
public OperationExprBase<AnalyticExpr<X>...>
65 AnalyticOperationExpr(std::shared_ptr<AnalyticExpr<X>>... x)
69 AnalyticOperationExpr(
const AnalyticOperationExpr<C,Y,X...>& e)
73 std::shared_ptr<ExprBase> copy()
const
75 return std::make_shared<AnalyticOperationExpr<C,Y,X...>>(*this);
78 void replace_arg(
const ExprID& old_arg_id,
const std::shared_ptr<ExprBase>& new_expr)
83 Y fwd_eval(ValuesMap& v, Index total_input_size,
bool natural_eval)
const
86 [
this,&v,total_input_size,natural_eval](
auto &&... x)
89 return AnalyticExpr<Y>::init_value(v, C::fwd_natural(x->fwd_eval(v, total_input_size, natural_eval)...));
92 return AnalyticExpr<Y>::init_value(v, C::fwd_centered(x->fwd_eval(v, total_input_size, natural_eval)...));
97 void bwd_eval(ValuesMap& v)
const
99 auto y = AnalyticExpr<Y>::value(v);
101 std::apply([&v,y](
auto &&... x)
103 C::bwd(y.a, x->value(v).a...);
106 std::apply([&v](
auto &&... x)
108 (x->bwd_eval(v), ...);
112 virtual std::string str(
bool in_parentheses =
false)
const
115 std::apply([&s](
auto &&... x)
119 return in_parentheses ?
"(" + s +
")" : s;
122 virtual bool is_str_leaf()
const
127 std::pair<Index,Index> output_shape()
const
129 std::pair<Index,Index> s;
130 std::apply([&s](
auto &&... x)
132 s = C::output_shape(x...);
137 virtual bool belongs_to_args_list(
const FunctionArgsList& args)
const
141 std::apply([&b,args](
auto &&... x)
143 ((b &= x->belongs_to_args_list(args)), ...);
Abstract base class for representing an expression.
Definition codac2_ExprBase.h:85
const ExprID & unique_id() const
Returns the unique identifier of the expression.
A base class for expressions representing operations with multiple operands.
Definition codac2_ExprBase.h:164
OperationExprBase(std::shared_ptr< X >... x)
Definition codac2_ExprBase.h:176
std::tuple< std::shared_ptr< X >... > _x
Definition codac2_ExprBase.h:258