codac 1.5.6
Loading...
Searching...
No Matches
codac2_AnalyticExpr.h
Go to the documentation of this file.
1
9
10#pragma once
11
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include "codac2_ExprBase.h"
#include "codac2_Domain.h"
#include "codac2_AnalyticType.h"
19
20namespace codac2
21{
22 using ValuesMap = std::map<ExprID,std::shared_ptr<AnalyticTypeBase>>;
23
24 template<typename T>
25 class AnalyticExpr : public ExprBase
26 {
27 public:
28
29 virtual T fwd_eval(ValuesMap& v, Index total_input_size, bool natural_eval) const = 0;
30 virtual void bwd_eval(ValuesMap& v) const = 0;
31 virtual std::pair<Index,Index> output_shape() const = 0;
32
33 T init_value(ValuesMap& v, const T& x) const
34 {
35 auto it = v.find(unique_id());
36
37 if(it == v.end())
38 {
39 v[unique_id()] = std::make_shared<T>(x);
40 return x;
41 }
42
43 *std::dynamic_pointer_cast<T>(it->second) = x;
44 return *std::dynamic_pointer_cast<T>(it->second);
45 }
46
47 T& value(ValuesMap& v) const
48 {
49 assert(v.find(unique_id()) != v.end() && "argument cannot be found");
50 return *std::dynamic_pointer_cast<T>(v[unique_id()]);
51 }
52
53 virtual bool belongs_to_args_list(const FunctionArgsList& args) const = 0;
54 virtual std::string str(bool in_parentheses = false) const = 0;
55 virtual bool is_str_leaf() const = 0;
56 };
57
58 template<typename C,typename Y,typename... X>
59 class AnalyticOperationExpr : public AnalyticExpr<Y>, public OperationExprBase<AnalyticExpr<X>...>
60 {
61 public:
62
63 AnalyticOperationExpr(std::shared_ptr<AnalyticExpr<X>>... x)
64 : OperationExprBase<AnalyticExpr<X>...>(x...)
65 { }
66
67 AnalyticOperationExpr(const AnalyticOperationExpr<C,Y,X...>& e)
68 : OperationExprBase<AnalyticExpr<X>...>(e)
69 { }
70
71 std::shared_ptr<ExprBase> copy() const
72 {
73 return std::make_shared<AnalyticOperationExpr<C,Y,X...>>(*this);
74 }
75
76 void replace_arg(const ExprID& old_arg_id, const std::shared_ptr<ExprBase>& new_expr)
77 {
78 return OperationExprBase<AnalyticExpr<X>...>::replace_arg(old_arg_id, new_expr);
79 }
80
81 Y fwd_eval(ValuesMap& v, Index total_input_size, bool natural_eval) const
82 {
83 return std::apply(
84 [this,&v,total_input_size,natural_eval](auto &&... x)
85 {
86 if(natural_eval)
87 return AnalyticExpr<Y>::init_value(v, C::fwd_natural(x->fwd_eval(v, total_input_size, natural_eval)...));
88
89 else
90 return AnalyticExpr<Y>::init_value(v, C::fwd_centered(x->fwd_eval(v, total_input_size, natural_eval)...));
91 },
92 this->_x);
93 }
94
95 void bwd_eval(ValuesMap& v) const
96 {
97 auto y = AnalyticExpr<Y>::value(v);
98
99 std::apply([&v,y](auto &&... x)
100 {
101 C::bwd(y.a, x->value(v).a...);
102 }, this->_x);
103
104 std::apply([&v](auto &&... x)
105 {
106 (x->bwd_eval(v), ...);
107 }, this->_x);
108 }
109
110 virtual std::string str(bool in_parentheses = false) const
111 {
112 std::string s = std::apply([](auto &&... x) {
113 return C::str(x...);
114 }, this->_x);
115 return in_parentheses ? "(" + s + ")" : s;
116 }
117
118 virtual bool is_str_leaf() const
119 {
120 return false;
121 }
122
123 std::pair<Index,Index> output_shape() const
124 {
125 std::pair<Index,Index> s;
126 std::apply([&s](auto &&... x)
127 {
128 s = C::output_shape(x...);
129 }, this->_x);
130 return s;
131 }
132
133 virtual bool belongs_to_args_list(const FunctionArgsList& args) const
134 {
135 bool b = true;
136
137 std::apply([&b,args](auto &&... x)
138 {
139 ((b &= x->belongs_to_args_list(args)), ...);
140 }, this->_x);
141
142 return b;
143 }
144 };
145}
Abstract base class for representing an expression.
Definition codac2_ExprBase.h:85
const ExprID & unique_id() const
Returns the unique identifier of the expression.
A base class for expressions representing operations with multiple operands.
Definition codac2_ExprBase.h:164
OperationExprBase(std::shared_ptr< X >... x)
Definition codac2_ExprBase.h:176
std::tuple< std::shared_ptr< X >... > _x
Definition codac2_ExprBase.h:258