codac 2.0.0
Loading...
Searching...
No Matches
codac2_AnalyticExpr.h
Go to the documentation of this file.
1
9
10#pragma once
11
12#include <map>
13#include <memory>
14#include <utility>
15#include "codac2_ExprBase.h"
16#include "codac2_Domain.h"
18#include "codac2_AnalyticType.h"
19
20namespace codac2
21{
22 using ValuesMap = std::map<ExprID,std::shared_ptr<AnalyticTypeBase>>;
23
24 template<typename T>
25 class AnalyticExpr : public ExprBase
26 {
27 public:
28
29 virtual T fwd_eval(ValuesMap& v, Index total_input_size, bool natural_eval) const = 0;
30 virtual void bwd_eval(ValuesMap& v) const = 0;
31 virtual std::pair<Index,Index> output_shape() const = 0;
32
33 const T& init_value(ValuesMap& v, const T& x) const
34 {
35 auto& p = v[unique_id()];
36
37 if(!p)
38 {
39 p = std::make_shared<T>(x);
40 return *static_cast<T*>(p.get());
41 }
42
43 else
44 {
45 auto pt = std::dynamic_pointer_cast<T>(p);
46 assert(pt && "Type mismatch in ValuesMap for this ExprID");
47 *pt = x;
48 return *pt;
49 }
50 }
51
52 T& value(ValuesMap& v) const
53 {
54 auto it = v.find(unique_id());
55 assert(it != v.end() && "argument cannot be found");
56 auto p = std::dynamic_pointer_cast<T>(it->second);
57 assert(p && "Type mismatch in ValuesMap for this ExprID");
58 return *p;
59 }
60
61 virtual bool belongs_to_args_list(const FunctionArgsList& args) const = 0;
62 virtual std::string str(bool in_parentheses = false) const = 0;
63 virtual bool is_str_leaf() const = 0;
64 };
65
66 template<typename C,typename Y,typename... X>
67 class AnalyticOperationExpr : public AnalyticExpr<Y>, public OperationExprBase<AnalyticExpr<X>...>
68 {
69 public:
70
71 AnalyticOperationExpr(std::shared_ptr<AnalyticExpr<X>>... x)
72 : OperationExprBase<AnalyticExpr<X>...>(x...)
73 { }
74
75 AnalyticOperationExpr(const AnalyticOperationExpr<C,Y,X...>& e)
76 : OperationExprBase<AnalyticExpr<X>...>(e)
77 { }
78
79 std::shared_ptr<ExprBase> copy() const
80 {
81 return std::make_shared<AnalyticOperationExpr<C,Y,X...>>(*this);
82 }
83
84 void replace_arg(const ExprID& old_arg_id, const std::shared_ptr<ExprBase>& new_expr)
85 {
86 return OperationExprBase<AnalyticExpr<X>...>::replace_arg(old_arg_id, new_expr);
87 }
88
89 Y fwd_eval(ValuesMap& v, Index total_input_size, bool natural_eval) const
90 {
91 return std::apply(
92 [this,&v,total_input_size,natural_eval](auto &&... x)
93 {
94 if(natural_eval)
95 return AnalyticExpr<Y>::init_value(v,
96 C::fwd_natural(x->fwd_eval(v, total_input_size, natural_eval)...));
97
98 else
99 return AnalyticExpr<Y>::init_value(v,
100 C::fwd_centered(x->fwd_eval(v, total_input_size, natural_eval)...));
101 },
102 this->_x);
103 }
104
105 void bwd_eval(ValuesMap& v) const
106 {
107 auto y = AnalyticExpr<Y>::value(v);
108
109 std::apply([&v,y](auto &&... x)
110 {
111 C::bwd(y.a, x->value(v).a...);
112 }, this->_x);
113
114 std::apply([&v](auto &&... x)
115 {
116 (x->bwd_eval(v), ...);
117 }, this->_x);
118 }
119
120 virtual std::string str(bool in_parentheses = false) const
121 {
122 std::string s = std::apply([](auto &&... x) {
123 return C::str(x...);
124 }, this->_x);
125 return in_parentheses ? "(" + s + ")" : s;
126 }
127
128 virtual bool is_str_leaf() const
129 {
130 return false;
131 }
132
133 std::pair<Index,Index> output_shape() const
134 {
135 std::pair<Index,Index> s;
136 std::apply([&s](auto &&... x)
137 {
138 s = C::output_shape(x...);
139 }, this->_x);
140 return s;
141 }
142
143 virtual bool belongs_to_args_list(const FunctionArgsList& args) const
144 {
145 bool b = true;
146
147 std::apply([&b,&args](auto &&... x)
148 {
149 ((b &= x->belongs_to_args_list(args)), ...);
150 }, this->_x);
151
152 return b;
153 }
154 };
155}
Abstract base class for representing an expression.
Definition codac2_ExprBase.h:86
const ExprID & unique_id() const
Returns the unique identifier of the expression.
virtual std::shared_ptr< ExprBase > copy() const =0
Creates a copy of the current expression.
A base class for expressions representing operations with multiple operands.
Definition codac2_ExprBase.h:165
OperationExprBase(std::shared_ptr< X >... x)
Definition codac2_ExprBase.h:177
std::tuple< std::shared_ptr< X >... > _x
Definition codac2_ExprBase.h:259
Definition codac2_OctaSym.h:21