-
Notifications
You must be signed in to change notification settings - Fork 0
/
Var.h
124 lines (96 loc) · 2.65 KB
/
Var.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#pragma once
#include <memory>
#include <utility>
#include "armadillo/armadillo"
namespace netn {
// Forward declarations.
// Model<T> is the base class of Var<T>; its definition presumably comes from
// "Model.h", which this header includes after the first namespace block.
// Component is defined below but is already named by the return type of
// IVectorizable::component().
template <class T> class Model;
struct Component;
// Interface for objects whose value can be addressed one scalar component at
// a time; Var<T> implements it.
class IVectorizable {
public:
    virtual ~IVectorizable() = default;

    // Number of scalar components this object exposes.
    virtual int dimension() const = 0;

    // Returns a handle naming component `index` of this object. The handle
    // keeps a reference to *this, so it must not outlive this object.
    Component component(int index) const;
};
// Names one scalar component (index `id`) of an IVectorizable.
//
// `v` is a non-owning reference: a Component must not outlive the object it
// was created from (e.g. a component of a temporary Var dangles immediately).
// Because of the reference member, Component is copy-constructible but not
// copy-assignable. The copy/move constructors are the implicit memberwise
// ones (Rule of Zero); the former hand-written copy constructor did exactly
// the same thing.
struct Component {
    Component(const IVectorizable & v, int id) : v(v), id(id) {}

    const IVectorizable & v;
    int id;
};
// A leaf variable holding a value of type T.
//
// The value lives behind a std::shared_ptr. Copies made by the copy
// constructor (and by toModel()) share the same storage, and operator==
// compares that pointer. Two Vars are therefore "the same variable" only when
// they share storage; equal values are not enough.
template <typename T>
class Var : public Model<T>, public IVectorizable {
public:
    // Constructs the held T from `args` (e.g. Var<double>(2.0)). The arguments
    // arrive by value and are owned by this constructor, so they are moved into
    // make_shared rather than copied a second time (this matters for large
    // values such as arma::mat). Not explicit: a single argument converts
    // implicitly to a Var.
    template <typename... Args>
    Var(Args... args) : _value(std::make_shared<T>(std::move(args)...)) {}

    // Shares other's storage. Note that the base subobjects are
    // default-initialized here, not copied from `other`.
    Var(const Var & other) : _value(other._value) {}

    virtual ~Var() = default;

    // Number of scalar components: 1 in general (specialized for arma::mat).
    int dimension() const override { return 1; }

    // Blank (zero) object shaped like the held value.
    T createEmptyCopy() const;
    // Writes `value` into component i of emptyCopy.
    void setElementOfCopy(int i, T & emptyCopy, double value) const;

    // Direct access to the shared value; writes are visible through every
    // copy of this Var.
    T & operator*() { return *_value; }
    const T & operator*() const { return *_value; }

    // True iff `other` is a Var<T> sharing this variable's storage.
    const bool operator==(const IVectorizable & other) const;

    // Current value, returned by copy.
    T eval() const override { return *_value; }
    // Derivative of this variable with respect to one component.
    T derivPart(const Component & component) const override;
    // A copy of this variable (sharing its storage) as a Model<T>.
    std::shared_ptr<Model<T>> toModel() const override;

private:
    std::shared_ptr<T> _value;
};
// Convenience aliases for the common variable types.
using Matrix = Var<arma::mat>;
using Scalar = Var<double>;
}
#include "Model.h"
namespace netn {
// Builds a Component handle that refers back to *this. The handle holds a
// reference, so it is only valid while this object is alive. `id` is not
// range-checked here.
inline Component IVectorizable::component(int id) const {
    return {*this, id};
}
// A matrix variable exposes one scalar component per element. The element
// count is an arma::uword; the narrowing to the interface's int is explicit.
template <>
inline int Var<arma::mat>::dimension() const {
    const arma::mat & held = *_value;
    return static_cast<int>(held.n_elem);
}
template <typename T>
inline T Var<T>::createEmptyCopy() const {
return *_value - *_value;
}
// Generic case: a non-matrix T exposes a single component (dimension() is 1),
// so the index is irrelevant and the whole target is overwritten with `value`.
template <typename T>
inline void Var<T>::setElementOfCopy(int /*index*/, T & target, double value) const {
    target = value;
}
// Writes `value` into element `i` (linear index) of `copy`.
// `i` is a plain caller-supplied int, and the former unchecked .at(i) turned
// an out-of-range index into an out-of-bounds write. Mat::operator() is
// bounds-checked by Armadillo unless ARMA_NO_DEBUG is defined, and a negative
// `i` becomes a huge uword that fails that check.
template <>
inline void Var<arma::mat>::setElementOfCopy(int i, arma::mat & copy, double value) const {
    copy(static_cast<arma::uword>(i)) = value;
}
// True iff `other` is a Var<T> whose shared value pointer is the same as this
// one's, i.e. both refer to the same variable storage.
//
// Uses the pointer form of dynamic_cast, which yields nullptr on a type
// mismatch, instead of the reference form plus catch(std::bad_cast). This
// avoids throwing an exception on every comparison against a different
// variable type (derivPart calls this), and no longer needs <typeinfo>.
// The return type stays `const bool` to match the in-class declaration.
template <typename T>
inline const bool Var<T>::operator==(const IVectorizable & other) const {
    const Var<T> * var = dynamic_cast<const Var<T> *>(&other);
    return var != nullptr && var->_value == _value;
}
// Derivative of this variable with respect to `component`: 1 when the
// component belongs to this very variable (same shared storage), else 0.
// The component index is not consulted in the generic case.
template <typename T>
inline T Var<T>::derivPart(const Component & component) const {
    const bool isSelf = (*this == component.v);
    return isSelf ? 1 : 0;
}
// Derivative of a matrix variable with respect to one scalar component.
// The result is a zero matrix shaped like the held value. When the component
// belongs to this variable (same shared storage), it has a 1 at linear index
// component.id.
template <>
inline arma::mat Var<arma::mat>::derivPart(const Component & component) const {
    arma::mat grad(_value->n_rows, _value->n_cols, arma::fill::zeros);
    if (*this == component.v) {
        // component.id comes unvalidated from IVectorizable::component(int).
        // Mat::operator() is bounds-checked by Armadillo (unless ARMA_NO_DEBUG
        // is defined), whereas the former .at() wrote out of bounds silently.
        grad(static_cast<arma::uword>(component.id)) = 1;
    }
    return grad;
}
// Wraps a copy of this variable as a Model<T>. The copy constructor shares
// _value, so the returned model refers to the same variable storage.
template <typename T>
inline std::shared_ptr<Model<T>> Var<T>::toModel() const {
    std::shared_ptr<Model<T>> model = std::make_shared<Var<T>>(*this);
    return model;
}
}