-
Notifications
You must be signed in to change notification settings - Fork 0
/
Model.inl
29 lines (24 loc) · 921 Bytes
/
Model.inl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
namespace netn {
template<typename T>
template<typename Var_T>
inline Var_T Model<T>::computeGradient(const Var<Var_T> & var) {
int size = var.dimension();
Var_T gradient = var.createEmptyCopy();
for (int i = 0; i < size; i++) {
T deriv = derivPart({ var, i });
// si T n'est pas convertible en double, la fonction ne compile pas...
// TODO il faudrait une erreur plus explicite à la compilation dans ce cas (du genre "invalid type : matrix")
var.setElementOfCopy(i, gradient, deriv);
}
return gradient;
}
template <typename T>
std::tuple<> computeGradients(const Model<T> & model) {
return std::make_tuple();
}
template<typename T, typename Var_T, typename ...Vars>
std::tuple<Vars...> computeGradients(const Model<T>& model, const Var_T & var, Vars... vars) {
Var_T gradient = computeGradient(model, var);
return std::tuple_cat(gradient, computeGradients(model, vars...));
}
}