codac 2.0.0
Loading...
Searching...
No Matches
codac2_AnalyticExpr.h
Go to the documentation of this file.
1
9
10#pragma once
11
12#include <map>
13#include <memory>
14#include <utility>
15#include "codac2_ExprBase.h"
16#include "codac2_Domain.h"
18#include "codac2_AnalyticType.h"
19
20namespace codac2
21{
// Associates an expression identifier (ExprID) with the value attached to that
// expression, held through a shared pointer to the AnalyticTypeBase base type.
22 using ValuesMap = std::map<ExprID,std::shared_ptr<AnalyticTypeBase>>;
23
24 template<typename T>
25 class AnalyticExpr : public ExprBase
26 {
27 public:
28
29 AnalyticExpr<T>& operator=(const AnalyticExpr<T>& x) = delete;
30
31 virtual T fwd_eval(ValuesMap& v, Index total_input_size, bool natural_eval) const = 0;
32 virtual void bwd_eval(ValuesMap& v) const = 0;
33 virtual std::pair<Index,Index> output_shape() const = 0;
34
35 T init_value(ValuesMap& v, const T& x) const
36 {
37 auto it = v.find(unique_id());
38
39 if(it == v.end())
40 {
41 v[unique_id()] = std::make_shared<T>(x);
42 return x;
43 }
44
45 *std::dynamic_pointer_cast<T>(it->second) = x;
46 return *std::dynamic_pointer_cast<T>(it->second);
47 }
48
49 T& value(ValuesMap& v) const
50 {
51 assert(v.find(unique_id()) != v.end() && "argument cannot be found");
52 return *std::dynamic_pointer_cast<T>(v[unique_id()]);
53 }
54
55 virtual bool belongs_to_args_list(const FunctionArgsList& args) const = 0;
56 virtual std::string str(bool in_parentheses = false) const = 0;
57 virtual bool is_str_leaf() const = 0;
58 };
59
60 template<typename C,typename Y,typename... X>
61 class AnalyticOperationExpr : public AnalyticExpr<Y>, public OperationExprBase<AnalyticExpr<X>...>
62 {
63 public:
64
65 AnalyticOperationExpr(std::shared_ptr<AnalyticExpr<X>>... x)
66 : OperationExprBase<AnalyticExpr<X>...>(x...)
67 { }
68
69 AnalyticOperationExpr(const AnalyticOperationExpr<C,Y,X...>& e)
70 : OperationExprBase<AnalyticExpr<X>...>(e)
71 { }
72
73 std::shared_ptr<ExprBase> copy() const
74 {
75 return std::make_shared<AnalyticOperationExpr<C,Y,X...>>(*this);
76 }
77
78 void replace_arg(const ExprID& old_arg_id, const std::shared_ptr<ExprBase>& new_expr)
79 {
80 return OperationExprBase<AnalyticExpr<X>...>::replace_arg(old_arg_id, new_expr);
81 }
82
83 Y fwd_eval(ValuesMap& v, Index total_input_size, bool natural_eval) const
84 {
85 return std::apply(
86 [this,&v,total_input_size,natural_eval](auto &&... x)
87 {
88 if(natural_eval)
89 return AnalyticExpr<Y>::init_value(v, C::fwd_natural(x->fwd_eval(v, total_input_size, natural_eval)...));
90
91 else
92 return AnalyticExpr<Y>::init_value(v, C::fwd_centered(x->fwd_eval(v, total_input_size, natural_eval)...));
93 },
94 this->_x);
95 }
96
97 void bwd_eval(ValuesMap& v) const
98 {
99 auto y = AnalyticExpr<Y>::value(v);
100
101 std::apply([&v,y](auto &&... x)
102 {
103 C::bwd(y.a, x->value(v).a...);
104 }, this->_x);
105
106 std::apply([&v](auto &&... x)
107 {
108 (x->bwd_eval(v), ...);
109 }, this->_x);
110 }
111
112 virtual std::string str(bool in_parentheses = false) const
113 {
114 std::string s;
115 std::apply([&s](auto &&... x)
116 {
117 s = C::str(x...);
118 }, this->_x);
119 return in_parentheses ? "(" + s + ")" : s;
120 }
121
122 virtual bool is_str_leaf() const
123 {
124 return false;
125 }
126
127 std::pair<Index,Index> output_shape() const
128 {
129 std::pair<Index,Index> s;
130 std::apply([&s](auto &&... x)
131 {
132 s = C::output_shape(x...);
133 }, this->_x);
134 return s;
135 }
136
137 virtual bool belongs_to_args_list(const FunctionArgsList& args) const
138 {
139 bool b = true;
140
141 std::apply([&b,args](auto &&... x)
142 {
143 ((b &= x->belongs_to_args_list(args)), ...);
144 }, this->_x);
145
146 return b;
147 }
148 };
149}
Abstract base class for representing an expression.
Definition codac2_ExprBase.h:85
const ExprID & unique_id() const
Returns the unique identifier of the expression.
A base class for expressions representing operations with multiple operands.
Definition codac2_ExprBase.h:164
OperationExprBase(std::shared_ptr< X >... x)
Definition codac2_ExprBase.h:176
std::tuple< std::shared_ptr< X >... > _x
Definition codac2_ExprBase.h:258