22 using ValuesMap = std::map<ExprID,std::shared_ptr<AnalyticTypeBase>>;
29 virtual T fwd_eval(ValuesMap& v, Index total_input_size,
bool natural_eval)
const = 0;
30 virtual void bwd_eval(ValuesMap& v)
const = 0;
31 virtual std::pair<Index,Index> output_shape()
const = 0;
33 T init_value(ValuesMap& v,
const T& x)
const
43 *std::dynamic_pointer_cast<T>(it->second) = x;
44 return *std::dynamic_pointer_cast<T>(it->second);
47 T& value(ValuesMap& v)
const
49 assert(v.find(
unique_id()) != v.end() &&
"argument cannot be found");
50 return *std::dynamic_pointer_cast<T>(v[
unique_id()]);
53 virtual bool belongs_to_args_list(
const FunctionArgsList& args)
const = 0;
54 virtual std::string str(
bool in_parentheses =
false)
const = 0;
55 virtual bool is_str_leaf()
const = 0;
58 template<
typename C,
typename Y,
typename... X>
59 class AnalyticOperationExpr :
public AnalyticExpr<Y>,
public OperationExprBase<AnalyticExpr<X>...>
63 AnalyticOperationExpr(std::shared_ptr<AnalyticExpr<X>>... x)
67 AnalyticOperationExpr(
const AnalyticOperationExpr<C,Y,X...>& e)
71 std::shared_ptr<ExprBase> copy()
const
73 return std::make_shared<AnalyticOperationExpr<C,Y,X...>>(*this);
76 void replace_arg(
const ExprID& old_arg_id,
const std::shared_ptr<ExprBase>& new_expr)
81 Y fwd_eval(ValuesMap& v, Index total_input_size,
bool natural_eval)
const
84 [
this,&v,total_input_size,natural_eval](
auto &&... x)
87 return AnalyticExpr<Y>::init_value(v, C::fwd_natural(x->fwd_eval(v, total_input_size, natural_eval)...));
90 return AnalyticExpr<Y>::init_value(v, C::fwd_centered(x->fwd_eval(v, total_input_size, natural_eval)...));
95 void bwd_eval(ValuesMap& v)
const
97 auto y = AnalyticExpr<Y>::value(v);
99 std::apply([&v,y](
auto &&... x)
101 C::bwd(y.a, x->value(v).a...);
104 std::apply([&v](
auto &&... x)
106 (x->bwd_eval(v), ...);
110 virtual std::string str(
bool in_parentheses =
false)
const
112 std::string s = std::apply([](
auto &&... x) {
115 return in_parentheses ?
"(" + s +
")" : s;
118 virtual bool is_str_leaf()
const
123 std::pair<Index,Index> output_shape()
const
125 std::pair<Index,Index> s;
126 std::apply([&s](
auto &&... x)
128 s = C::output_shape(x...);
133 virtual bool belongs_to_args_list(
const FunctionArgsList& args)
const
137 std::apply([&b,args](
auto &&... x)
139 ((b &= x->belongs_to_args_list(args)), ...);
Abstract base class for representing an expression.
Definition codac2_ExprBase.h:85
const ExprID & unique_id() const
Returns the unique identifier of the expression.
A base class for expressions representing operations with multiple operands.
Definition codac2_ExprBase.h:164
OperationExprBase(std::shared_ptr< X >... x)
Definition codac2_ExprBase.h:176
std::tuple< std::shared_ptr< X >... > _x
Definition codac2_ExprBase.h:258